From c4790c854d878b3a7093c547c9b081c06efbf677 Mon Sep 17 00:00:00 2001 From: OctavianLee Date: Wed, 6 Apr 2016 22:26:12 +0800 Subject: [PATCH 001/547] Add the compatibility with PYPY by psycopg2cffi. --- pony/orm/dbproviders/postgres.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 7a233fecd..4c2d48dff 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -5,7 +5,12 @@ from datetime import datetime, date, time, timedelta from uuid import UUID -import psycopg2 +try: + import psycopg2 +except ImportError: + from psycopg2cffi import compat + compat.register() + from psycopg2 import extensions import psycopg2.extras From 2524ccc8a8bf5af33c7674dd8989fb9278cbbb6b Mon Sep 17 00:00:00 2001 From: OctavianLee Date: Wed, 6 Apr 2016 22:26:12 +0800 Subject: [PATCH 002/547] Add the compatibility with PYPY by psycopg2cffi. --- pony/orm/dbproviders/postgres.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 7a233fecd..4c2d48dff 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -5,7 +5,12 @@ from datetime import datetime, date, time, timedelta from uuid import UUID -import psycopg2 +try: + import psycopg2 +except ImportError: + from psycopg2cffi import compat + compat.register() + from psycopg2 import extensions import psycopg2.extras From fedcb70f0be1c36a826bf7a5e475b4422c850024 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 28 Jul 2016 12:33:47 +0300 Subject: [PATCH 003/547] Oracle bug fixed --- pony/orm/dbproviders/oracle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index b5226e3cd..d08648db1 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -498,13 +498,14 @@ def 
output_type_handler(cursor, name, defaultType, size, precision, scale): class OraPool(object): forked_pools = [] def __init__(pool, **kwargs): + pool.kwargs = kwargs pool.cx_pool = cx_Oracle.SessionPool(**kwargs) pool.pid = os.getpid() def connect(pool): pid = os.getpid() if pool.pid != pid: pool.forked_pools.append((pool.cx_pool, pool.pid)) - pool.cx_pool = cx_Oracle.SessionPool(**kwargs) + pool.cx_pool = cx_Oracle.SessionPool(**pool.kwargs) pool.pid = os.getpid() if core.debug: log_orm('GET CONNECTION') con = pool.cx_pool.acquire() From fd6f2f4e236cb645380f9ecaac862b7a16fbca8f Mon Sep 17 00:00:00 2001 From: Vitalii Date: Wed, 23 Mar 2016 19:24:23 +0300 Subject: [PATCH 004/547] Remove restriction on using dict type in queries --- pony/orm/asttranslation.py | 21 ++++++++++--------- pony/orm/ormtypes.py | 3 +-- pony/orm/tests/test_declarative_exceptions.py | 4 ++-- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index f989d18c5..e7bda4a9e 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -274,16 +274,17 @@ def postCallFunc(translator, node): expr = '.'.join(reversed(attrs)) x = eval(expr, translator.globals, translator.locals) try: hash(x) - except TypeError: x = None - if x in translator.special_functions: - if x.__name__ == 'raw_sql': node.raw_sql = True - else: node.external = False - elif x in translator.const_functions: - for arg in node.args: - if not arg.constant: return - if node.star_args is not None and not node.star_args.constant: return - if node.dstar_args is not None and not node.dstar_args.constant: return - node.constant = True + except TypeError: pass + else: + if x in translator.special_functions: + if x.__name__ == 'raw_sql': node.raw_sql = True + else: node.external = False + elif x in translator.const_functions: + for arg in node.args: + if not arg.constant: return + if node.star_args is not None and not node.star_args.constant: return + if 
node.dstar_args is not None and not node.dstar_args.constant: return + node.constant = True extractors_cache = {} diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 050a93a79..ae18c9412 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -130,8 +130,6 @@ def __ne__(self, other): def get_normalized_type_of(value): t = type(value) if t is tuple: return tuple(get_normalized_type_of(item) for item in value) - try: hash(value) # without this, cannot do tests like 'if value in special_fucntions...' - except TypeError: throw(TypeError, 'Unsupported type %r' % t.__name__) if t.__name__ == 'EntityMeta': return SetType(value) if t.__name__ == 'EntityIter': return SetType(value.entity) if PY2 and isinstance(value, str): @@ -155,6 +153,7 @@ def normalize_type(t): t = type_normalization_dict.get(t, t) if t in primitive_types: return t if issubclass(t, basestring): return unicode + if issubclass(t, dict): return dict throw(TypeError, 'Unsupported type %r' % t.__name__) coercions = { diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 0d66417a0..04974a91f 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -67,7 +67,7 @@ def test4(self): select(s for s in Student if s.name.upper(*args)) if sys.version_info[:2] < (3, 5): # TODO - @raises_exception(TypeError, "Expression `{'a':'b', 'c':'d'}` has unsupported type 'dict'") + @raises_exception(NotImplementedError) # "**{'a': 'b', 'c': 'd'} is not supported def test5(self): select(s for s in Student if s.name.upper(**{'a':'b', 'c':'d'})) @@ -210,7 +210,7 @@ def test49(self): sum(s.name for s in Student) if sys.version_info[:2] < (3, 5): # TODO - @raises_exception(TypeError, "Expression `{'a':'b'}` has unsupported type 'dict'") + @raises_exception(NotImplementedError) # "Parameter {'a': 'b'} has unsupported type 'dict' def test50(self): select(s for s in Student if s.name == {'a' : 'b'}) From 
68ffcb846580895db57ec7e937bf03ece5966b89 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 28 Jun 2016 13:37:02 +0300 Subject: [PATCH 005/547] Fix decompiling BUILD_MAP for Python 3.5 --- pony/orm/decompiling.py | 12 ++++++++---- pony/orm/tests/test_declarative_exceptions.py | 9 +++------ 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index d3927e4bc..d989caf14 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, izip, xrange -import types +import sys, types from opcode import opname as opnames, HAVE_ARGUMENT, EXTENDED_ARG, cmp_op from opcode import hasconst, hasname, hasjrel, haslocal, hascompare, hasfree @@ -150,9 +150,13 @@ def BINARY_SUBSCR(decompiler): def BUILD_LIST(decompiler, size): return ast.List(decompiler.pop_items(size)) - def BUILD_MAP(decompiler, not_used): - # Pushes a new empty dictionary object onto the stack. The argument is ignored and set to zero by the compiler - return ast.Dict(()) + def BUILD_MAP(decompiler, length): + if sys.version_info < (3, 5): + return ast.Dict(()) + data = decompiler.pop_items(2 * length) # [key1, value1, key2, value2, ...] + it = iter(data) + pairs = list(izip(it, it)) # [(key1, value1), (key2, value2), ...] 
+ return ast.Dict(pairs) def BUILD_SET(decompiler, size): return ast.Set(decompiler.pop_items(size)) diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 04974a91f..1e1d8b22b 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -208,12 +208,9 @@ def test48(self): @raises_exception(TypeError, "'sum' is valid for numeric attributes only") def test49(self): sum(s.name for s in Student) - - if sys.version_info[:2] < (3, 5): # TODO - @raises_exception(NotImplementedError) # "Parameter {'a': 'b'} has unsupported type 'dict' - def test50(self): - select(s for s in Student if s.name == {'a' : 'b'}) - + @raises_exception(NotImplementedError) # Parameter {'a': 'b'} has unsupported type 'dict' + def test50(self): + select(s for s in Student if s.name == {'a' : 'b'}) @raises_exception(IncomparableTypesError, "Incomparable types '%s' and 'int' in expression: s.name > a & 2" % unicode.__name__) def test51(self): a = 1 From e780dcf53e9885742f3f0e7e3233cc564e5c388a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 8 Jul 2016 19:57:44 +0300 Subject: [PATCH 006/547] Improved exception message --- pony/orm/sqltranslation.py | 5 ++++- pony/orm/tests/test_declarative_exceptions.py | 2 +- pony/orm/tests/test_query.py | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index f1385d383..1299ba9f9 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -77,6 +77,9 @@ def dispatch(translator, node): if hasattr(node, 'monad'): return # monad already assigned somehow if not getattr(node, 'external', False) or getattr(node, 'constant', False): return ASTTranslator.dispatch(translator, node) # default route + translator.call(translator.dispatch_external, node) + + def dispatch_external(translator, node): varkey = translator.filter_num, node.src t = 
translator.vartypes[varkey] tt = type(t) @@ -1550,7 +1553,7 @@ def new(translator, type, paramkey): elif type is buffer: cls = translator.BufferParamMonad elif type is UUID: cls = translator.UuidParamMonad elif isinstance(type, EntityMeta): cls = translator.ObjectParamMonad - else: throw(NotImplementedError, type) # pragma: no cover + else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type)) result = cls(translator, type, paramkey) result.aggregated = False return result diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 1e1d8b22b..8721272bb 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -208,7 +208,7 @@ def test48(self): @raises_exception(TypeError, "'sum' is valid for numeric attributes only") def test49(self): sum(s.name for s in Student) - @raises_exception(NotImplementedError) # Parameter {'a': 'b'} has unsupported type 'dict' + @raises_exception(NotImplementedError, "Parameter {'a':'b'} has unsupported type ") def test50(self): select(s for s in Student if s.name == {'a' : 'b'}) @raises_exception(IncomparableTypesError, "Incomparable types '%s' and 'int' in expression: s.name > a & 2" % unicode.__name__) diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 46d68f19d..7f6c6986c 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -58,7 +58,7 @@ def test6(self): def f1(x): return x + 1 select(s for s in Student if f1(s.gpa) > 3) - @raises_exception(NotImplementedError, "m1(s.gpa, 1) > 3") + @raises_exception(NotImplementedError, "m1") def test7(self): class C1(object): def method1(self, a, b): From 97db11cdefea25394d08b7c507c583ba50ddf2eb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Aug 2016 11:19:41 +0300 Subject: [PATCH 007/547] Empty Dict AST node should be marked as external --- pony/orm/asttranslation.py | 2 ++ 1 file changed, 2 
insertions(+) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index e7bda4a9e..fd03f6dbe 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -260,6 +260,8 @@ def postName(translator, node): node.external = True def postConst(translator, node): node.external = node.constant = True + def postDict(translator, node): + node.external = True def postKeyword(translator, node): node.constant = node.expr.constant def postCallFunc(translator, node): From b083c72dbe0961c88f9295a415f2fc81fcd12edc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Aug 2016 11:36:29 +0300 Subject: [PATCH 008/547] Empty List AST node should be marked as external --- pony/orm/asttranslation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index fd03f6dbe..7d67bfd0a 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -262,6 +262,8 @@ def postConst(translator, node): node.external = node.constant = True def postDict(translator, node): node.external = True + def postList(translator, node): + node.external = True def postKeyword(translator, node): node.constant = node.expr.constant def postCallFunc(translator, node): From 6f62bf6fd6f8140b3136226b1f91cd1984f33afc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 4 Jul 2016 14:46:25 +0300 Subject: [PATCH 009/547] Remove unused code --- pony/orm/sqltranslation.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 1299ba9f9..c55ba983f 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1721,7 +1721,6 @@ def negate(monad): return monad.translator.CmpMonad(cmp_negate[monad.op], monad.left, monad.right) def getsql(monad, subquery=None): op = monad.op - sql = [] left_sql = monad.left.getsql() if op == 'is': return [ sqland([ [ 'IS_NULL', item ] for item in left_sql ]) ] From c4a6b22c27c43f0de1af987dbb6c3f586236f1a8 Mon 
Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Jun 2016 19:48:53 +0300 Subject: [PATCH 010/547] Remove trailing whitespaces --- pony/orm/dbproviders/oracle.py | 14 +++++++------- pony/orm/dbproviders/postgres.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index d08648db1..7b654475a 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -53,7 +53,7 @@ def get_create_command(sequence): schema = sequence.table.schema seq_name = schema.provider.quote_name(sequence.name) return schema.case('CREATE SEQUENCE %s NOCACHE') % seq_name - + trigger_template = """ CREATE TRIGGER %s BEFORE INSERT ON %s @@ -88,7 +88,7 @@ def exists(trigger, provider, connection, case_sensitive=True): def get_create_command(trigger): schema = trigger.table.schema quote_name = schema.provider.quote_name - trigger_name = quote_name(trigger.name) + trigger_name = quote_name(trigger.name) table_name = quote_name(trigger.table.name) column_name = quote_name(trigger.column.name) seq_name = quote_name(trigger.sequence.name) @@ -111,7 +111,7 @@ class OraConstMonad(sqltranslation.ConstMonad): @staticmethod def new(translator, value): if value == '': value = None - return sqltranslation.ConstMonad.new(translator, value) + return sqltranslation.ConstMonad.new(translator, value) class OraTranslator(sqltranslation.SQLTranslator): dialect = 'Oracle' @@ -174,7 +174,7 @@ def SELECT(builder, *sections): else: indent0 = '' x = 't.*' - + if not limit: pass elif not offset: result = [ indent0, 'SELECT * FROM (\n' ] @@ -227,7 +227,7 @@ def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' class OraBoolConverter(dbapiprovider.BoolConverter): - if not PY2: + if not PY2: def py2sql(converter, val): # Fixes cx_Oracle 5.1.3 Python 3 bug: # "DatabaseError: OCI-22062: invalid input string [True]" @@ -287,7 +287,7 @@ def __init__(converter, provider, 
py_type, attr=None): dbapiprovider.TimeConverter.__init__(converter, provider, py_type, attr) if attr is not None and converter.precision > 0: # cx_Oracle 5.1.3 corrupts microseconds for values of DAY TO SECOND type - converter.precision = 0 + converter.precision = 0 def sql2py(converter, val): if isinstance(val, timedelta): total_seconds = val.days * (24 * 60 * 60) + val.seconds @@ -308,7 +308,7 @@ def __init__(converter, provider, py_type, attr=None): dbapiprovider.TimedeltaConverter.__init__(converter, provider, py_type, attr) if attr is not None and converter.precision > 0: # cx_Oracle 5.1.3 corrupts microseconds for values of DAY TO SECOND type - converter.precision = 0 + converter.precision = 0 class OraDatetimeConverter(dbapiprovider.DatetimeConverter): sql_type_name = 'TIMESTAMP' diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 4c2d48dff..c2bfba44e 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -190,7 +190,7 @@ def table_exists(provider, connection, table_name, case_sensitive=True): cursor.execute(sql, (schema_name, table_name)) row = cursor.fetchone() return row[0] if row is not None else None - + def index_exists(provider, connection, table_name, index_name, case_sensitive=True): schema_name, table_name = provider.split_table_name(table_name) cursor = connection.cursor() From cb6177d70c15e8a1bcc11bce51f05e9a39f8bc00 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 4 May 2016 17:37:07 +0300 Subject: [PATCH 011/547] Remove unused imports --- pony/orm/dbproviders/mysql.py | 2 +- pony/orm/dbproviders/sqlite.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 1a12b6978..4e9165a4d 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -28,7 +28,7 @@ mysql_module_name = 'pymysql' from pony.orm import core, dbschema, dbapiprovider -from pony.orm.core import 
log_orm, OperationalError +from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator from pony.orm.sqlbuilding import SQLBuilder, join diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 7964b151e..126487220 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -15,7 +15,7 @@ from pony.orm.core import log_orm from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions -from pony.utils import localbase, datetime2timestamp, timestamp2datetime, decorator, absolutize_path, throw +from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, throw NoneType = type(None) From 135d1284ac3ad25c155ac43c9d0b0978d72e03eb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 4 May 2016 14:08:15 +0300 Subject: [PATCH 012/547] Renaming: result -> value --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 82b95a916..20d20cfae 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1937,11 +1937,11 @@ def load(attr, obj): def __get__(attr, obj, cls=None): if obj is None: return attr if attr.pk_offset is not None: return attr.get(obj) - result = attr.get(obj) + value = attr.get(obj) bit = obj._bits_except_volatile_[attr] wbits = obj._wbits_ if wbits is not None and not wbits & bit: obj._rbits_ |= bit - return result + return value def get(attr, obj): if attr.pk_offset is None and obj._status_ in ('deleted', 'cancelled'): throw_object_was_deleted(obj) From 4151b61b836c57a49966e935c18857afbdceb4d5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 20 May 2016 10:01:38 +0300 Subject: [PATCH 013/547] Add 'optimistic' option to attribute --- pony/orm/core.py | 11 ++++++++--- pony/orm/dbapiprovider.py | 1 + 2 
files changed, 9 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 20d20cfae..2b549f29e 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1700,7 +1700,8 @@ class Attribute(object): 'id', 'pk_offset', 'pk_columns_offset', 'py_type', 'sql_type', 'entity', 'name', \ 'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \ 'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \ - 'cascade_delete', 'index', 'original_default', 'sql_default', 'py_check', 'hidden' + 'cascade_delete', 'index', 'original_default', 'sql_default', 'py_check', 'hidden', \ + 'optimistic' def __deepcopy__(attr, memo): return attr # Attribute cannot be cloned by deepcopy() @cut_traceback @@ -1765,6 +1766,7 @@ def __init__(attr, py_type, *args, **kwargs): attr.lazy = kwargs.pop('lazy', getattr(py_type, 'lazy', False)) attr.lazy_sql_cache = None attr.is_volatile = kwargs.pop('volatile', False) + attr.optimistic = kwargs.pop('optimistic', True) attr.sql_default = kwargs.pop('sql_default', None) attr.py_check = kwargs.pop('py_check', None) attr.hidden = kwargs.pop('hidden', False) @@ -4554,10 +4556,13 @@ def _construct_optimistic_criteria_(obj): optimistic_converters = [] optimistic_values = [] for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): + converters = attr.converters + assert converters + if not (attr.optimistic and converters[0].optimistic): continue dbval = obj._dbvals_[attr] optimistic_columns.extend(attr.columns) - if dbval is not None: converters = attr.converters - else: converters = repeat(None, len(attr.converters)) + if dbval is None: + converters = repeat(None, len(attr.converters)) optimistic_converters.extend(converters) optimistic_values.extend(attr.get_raw_values(dbval)) return optimistic_columns, optimistic_converters, optimistic_values diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index f5a4b8bbf..b0e625a95 100644 --- 
a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -326,6 +326,7 @@ def disconnect(pool): if con is not None: con.close() class Converter(object): + optimistic = True def __deepcopy__(converter, memo): return converter # Converter instances are "immutable" def __init__(converter, provider, py_type, attr=None): From 073700b6868d262ff25c05aace00e97935bb7a95 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Aug 2016 11:33:27 +0300 Subject: [PATCH 014/547] Fix ErrorSpecialFuncMonad handling: exception was not raised properly --- pony/orm/sqltranslation.py | 4 ++-- pony/orm/tests/test_declarative_sqltranslator.py | 2 +- pony/orm/tests/test_query.py | 2 +- pony/orm/tests/test_raw_sql.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index c55ba983f..1b72c1005 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -741,8 +741,8 @@ def postCallFunc(translator, node): kwargs[arg.name] = arg.expr.monad else: args.append(arg.monad) func_monad = node.node.monad - if isinstance(func_monad, ErrorSpecialFuncMonad): - 'Function %r cannot be used in this way: %s' % (func_monad.func.__name__, ast2src(node)) + if isinstance(func_monad, ErrorSpecialFuncMonad): throw(TypeError, + 'Function %r cannot be used this way: %s' % (func_monad.func.__name__, ast2src(node))) return func_monad(*args, **kwargs) def postKeyword(translator, node): pass # this node will be processed by postCallFunc diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 012906195..00a1af29a 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -343,7 +343,7 @@ def test_tuple_param_2(self): x = Student[1], None result = set(select(s for s in Student if s not in x)) self.assertEqual(result, set([Student[3]])) - @raises_exception(TypeError, "f(s)") + 
@raises_exception(TypeError, "Function 'f' cannot be used this way: f(s)") def test_unknown_func(self): def f(x): return x select(s for s in Student if f(s)) diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 7f6c6986c..0220e73e6 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -53,7 +53,7 @@ def test4(self): def test5(self): x = ['A'] select(s for s in Student if s.name == x) - @raises_exception(TypeError, "f1(s.gpa)") + @raises_exception(TypeError, "Function 'f1' cannot be used this way: f1(s.gpa)") def test6(self): def f1(x): return x + 1 diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index 56dce042e..151be61f9 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -152,7 +152,7 @@ def test_18(self): self.assertEqual(persons, [Person[1], Person[3], Person[2]]) @db_session - @raises_exception(TypeError, "raw_sql(p.name)") + @raises_exception(TypeError, "Function 'raw_sql' cannot be used this way: raw_sql(p.name)") def test_19(self): # raw_sql argument cannot depend on iterator variables select(p for p in Person if raw_sql(p.name))[:] From 5997f5428d1e312d274ebbd47b1706cf85237536 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 30 Jun 2016 13:41:31 +0300 Subject: [PATCH 015/547] Ellipsis support in decompiler/translator --- pony/orm/asttranslation.py | 2 ++ pony/orm/core.py | 33 +++++++++++++++++++++---- pony/orm/sqltranslation.py | 6 +++++ pony/thirdparty/compiler/transformer.py | 6 +++++ 4 files changed, 42 insertions(+), 5 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 7d67bfd0a..7e80ce910 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -171,6 +171,8 @@ def postConst(translator, node): s = str(value) if float(s) == value: return s return repr(value) + def postEllipsis(translator, node): + return '...' 
def postList(translator, node): node.priority = 1 return '[%s]' % ', '.join(item.src for item in node.nodes) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2b549f29e..1449243fd 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2,7 +2,7 @@ from pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \ basestring, unicode, buffer, int_types, builtins, pickle, with_metaclass -import json, re, sys, types, datetime, logging, itertools +import io, json, re, sys, types, datetime, logging, itertools from operator import attrgetter, itemgetter from itertools import chain, starmap, repeat from time import time @@ -4972,6 +4972,29 @@ def extract_vars(extractors, globals, locals, cells=None): def unpickle_query(query_result): return query_result +def persistent_id(obj): + if obj is Ellipsis: + return "Ellipsis" + +def persistent_load(persid): + if persid == "Ellipsis": + return Ellipsis + raise pickle.UnpicklingError("unsupported persistent object") + +def pickle_ast(val): + pickled = io.BytesIO() + pickler = pickle.Pickler(pickled) + pickler.persistent_id = persistent_id + pickler.dump(val) + return pickled + +def unpickle_ast(pickled): + pickled.seek(0) + unpickler = pickle.Unpickler(pickled) + unpickler.persistent_load = persistent_load + return unpickler.load() + + class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) @@ -4996,13 +5019,13 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False translator = database._translator_cache.get(query._key) if translator is None: - pickled_tree = pickle.dumps(tree, 2) - tree = pickle.loads(pickled_tree) # tree = deepcopy(tree) + pickled_tree = pickle_ast(tree) + tree = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls translator = translator_cls(tree, extractors, vartypes, left_join=left_join) name_path = 
translator.can_be_optimized() if name_path: - tree = pickle.loads(pickled_tree) # tree = deepcopy(tree) + tree = unpickle_ast(pickled_tree) # tree = deepcopy(tree) try: translator = translator_cls(tree, extractors, vartypes, left_join=True, optimize=name_path) except OptimizationFailed: translator.optimization_failed = True translator.pickled_tree = pickled_tree @@ -5311,7 +5334,7 @@ def _process_lambda(query, func, globals, locals, order_by): if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: - tree = pickle.loads(prev_translator.pickled_tree) # tree = deepcopy(tree) + tree = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) prev_extractors = prev_translator.extractors prev_vartypes = prev_translator.vartypes translator_cls = prev_translator.__class__ diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 1b72c1005..156f42647 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -659,6 +659,8 @@ def postConst(translator, node): return translator.ConstMonad.new(translator, value) else: return translator.ListMonad(translator, [ translator.ConstMonad.new(translator, item) for item in value ]) + def postEllipsis(translator, node): + return translator.ConstMonad.new(translator, Ellipsis) def postList(translator, node): return translator.ListMonad(translator, [ item.monad for item in node.nodes ]) def postTuple(translator, node): @@ -1633,6 +1635,7 @@ def new(translator, value): elif value_type is datetime: cls = translator.DatetimeConstMonad elif value_type is NoneType: cls = translator.NoneMonad elif value_type is buffer: cls = translator.BufferConstMonad + elif issubclass(value_type, type(Ellipsis)): cls = translator.EllipsisMonad else: throw(NotImplementedError, value_type) # pragma: no cover result = cls(translator, value) result.aggregated = False @@ -1653,6 +1656,9 @@ def __init__(monad, translator, value=None): assert value is None ConstMonad.__init__(monad, translator, 
value) +class EllipsisMonad(ConstMonad): + pass + class BufferConstMonad(BufferMixin, ConstMonad): pass class StringConstMonad(StringMixin, ConstMonad): diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index 8e08361a1..182712439 100644 --- a/pony/thirdparty/compiler/transformer.py +++ b/pony/thirdparty/compiler/transformer.py @@ -136,6 +136,9 @@ def __init__(self): if PY2: self._atom_dispatch.update({ token.BACKQUOTE: self.atom_backquote, }) + if not PY2: self._atom_dispatch.update({ + token.ELLIPSIS: self.atom_ellipsis + }) self.encoding = None def transform(self, tree): @@ -786,6 +789,9 @@ def atom_lbrace(self, nodelist): def atom_backquote(self, nodelist): return Backquote(self.com_node(nodelist[1])) + def atom_ellipsis(self, nodelist): + return Ellipsis() + def atom_number(self, nodelist): ### need to verify this matches compile.c k = eval(nodelist[0][1]) From 4622dd9882f733bb61941d3adb047f4668ce7186 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 30 Jun 2016 18:57:52 +0300 Subject: [PATCH 016/547] SQL equality comparison operation can be customized in converter for specific datatype now --- pony/orm/core.py | 64 +++++++++++++++++++------------------- pony/orm/dbapiprovider.py | 2 ++ pony/orm/sqltranslation.py | 8 +++-- 3 files changed, 39 insertions(+), 35 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 1449243fd..e3f3c0863 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1921,7 +1921,7 @@ def load(attr, obj): from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] pk_columns = entity._pk_columns_ pk_converters = entity._pk_converters_ - criteria_list = [ [ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] + criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ] sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + 
criteria_list ] sql, adapter = database._ast2sql(sql_ast) @@ -2413,7 +2413,7 @@ def param(i, j, converter): else: return [ 'PARAM', (i, j, None), converter ] if batch_size == 1: - return [ [ 'EQ', [ 'COLUMN', alias, column ], param(start, j, converter) ] + return [ [ converter.EQ, [ 'COLUMN', alias, column ], param(start, j, converter) ] for j, (column, converter) in enumerate(izip(columns, converters)) ] if len(columns) == 1: column = columns[0] @@ -2428,7 +2428,7 @@ def param(i, j, converter): condition = [ 'IN', row, param_list ] return [ condition ] else: - conditions = [ [ 'AND' ] + [ [ 'EQ', [ 'COLUMN', alias, column ], param(i+start, j, converter) ] + conditions = [ [ 'AND' ] + [ [ converter.EQ, [ 'COLUMN', alias, column ], param(i+start, j, converter) ] for j, (column, converter) in enumerate(izip(columns, converters)) ] for i in xrange(batch_size) ] return [ [ 'OR' ] + conditions ] @@ -2796,7 +2796,7 @@ def remove_m2m(attr, removed): columns = reverse.columns + attr.columns converters = reverse.converters + attr.converters for i, (column, converter) in enumerate(izip(columns, converters)): - where_list.append([ 'EQ', ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ]) + where_list.append([ converter.EQ, ['COLUMN', None, column], [ 'PARAM', (i, None, None), converter ] ]) from_ast = [ 'FROM', [ None, 'TABLE', attr.table ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) @@ -2894,7 +2894,7 @@ def is_empty(wrapper): if cached_sql is None: where_list = [ 'WHERE' ] for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)): - where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) + where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = rentity._table_ select_list, attr_offsets = rentity._construct_select_clause_() @@ -2947,7 +2947,7 @@ def 
count(wrapper): if cached_sql is None: where_list = [ 'WHERE' ] for i, (column, converter) in enumerate(izip(reverse.columns, reverse.converters)): - where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) + where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = reverse.entity._table_ else: table_name = attr.table sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], @@ -3785,7 +3785,8 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda if attr_is_none: where_list.append([ 'IS_NULL', [ 'COLUMN', None, attr.column ] ]) else: if len(attr.converters) > 1: throw(NotImplementedError) - where_list.append([ 'EQ', [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), attr.converters[0] ] ]) + converter = attr.converters[0] + where_list.append([ converter.EQ, [ 'COLUMN', None, attr.column ], [ 'PARAM', (attr, None, None), converter ] ]) elif not attr.columns: throw(NotImplementedError) else: attr_entity = attr.py_type; assert attr_entity == attr.reverse.entity @@ -3794,7 +3795,7 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda where_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ]) else: for j, (column, converter) in enumerate(izip(attr.columns, attr_entity._pk_converters_)): - where_list.append([ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ]) + where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ]) if not for_update: sql_ast = [ 'SELECT', select_list, from_list, where_list ] else: sql_ast = [ 'SELECT_FOR_UPDATE', bool(nowait), select_list, from_list, where_list ] @@ -4079,14 +4080,12 @@ def _get_attrs_(entity, only=None, exclude=None, with_collections=False, with_la entity._attrnames_cache_[key] = attrs return attrs -def populate_criteria_list(criteria_list, columns, 
converters, params_count=0, table_alias=None): - assert len(columns) == len(converters) - for column, converter in izip(columns, converters): - if converter is not None: - criteria_list.append([ 'EQ', [ 'COLUMN', table_alias, column ], - [ 'PARAM', (params_count, None, None), converter ] ]) +def populate_criteria_list(criteria_list, columns, operations, params_count=0, table_alias=None): + for column, op in izip(columns, operations): + if op == 'IS_NULL': + criteria_list.append([ op, [ 'COLUMN', None, column ] ]) else: - criteria_list.append([ 'IS_NULL', [ 'COLUMN', None, column ] ]) + criteria_list.append([ op, [ 'COLUMN', table_alias, column ], [ 'PARAM', (params_count, None, None), None ] ]) params_count += 1 return params_count @@ -4292,7 +4291,7 @@ def load(obj, *attrs): offsets.append(len(select_list) - 1) select_list.append([ 'COLUMN', None, column ]) from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ]] - criteria_list = [ [ 'EQ', [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] + criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] for i, (column, converter) in enumerate(izip(obj._pk_columns_, obj._pk_converters_)) ] where_list = [ 'WHERE' ] + criteria_list @@ -4553,19 +4552,21 @@ def _attrs_with_bit_(entity, attrs, mask=-1): if get_bit(attr) & mask: yield attr def _construct_optimistic_criteria_(obj): optimistic_columns = [] - optimistic_converters = [] optimistic_values = [] + optimistic_operations = [] for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): converters = attr.converters assert converters if not (attr.optimistic and converters[0].optimistic): continue dbval = obj._dbvals_[attr] optimistic_columns.extend(attr.columns) + values = attr.get_raw_values(dbval) + optimistic_values.extend(values) if dbval is None: - converters = repeat(None, len(attr.converters)) - optimistic_converters.extend(converters) - 
optimistic_values.extend(attr.get_raw_values(dbval)) - return optimistic_columns, optimistic_converters, optimistic_values + optimistic_operations.append('IS_NULL') + else: + optimistic_operations.extend(converter.EQ for converter in converters) + return optimistic_operations, optimistic_columns, optimistic_values def _save_principal_objects_(obj, dependent_objects): if dependent_objects is None: dependent_objects = [] elif obj in dependent_objects: @@ -4675,12 +4676,11 @@ def _save_updated_(obj): values.extend(attr.get_raw_values(val)) cache = obj._session_cache_ if obj not in cache.for_update: - optimistic_columns, optimistic_converters, optimistic_values = \ + optimistic_ops, optimistic_columns, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) - else: optimistic_columns = optimistic_converters = () - query_key = (tuple(update_columns), tuple(optimistic_columns), - tuple(converter is not None for converter in optimistic_converters)) + else: optimistic_columns = optimistic_values = optimistic_ops = () + query_key = tuple(update_columns), tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._update_sql_cache_.get(query_key) if cached_sql is None: @@ -4693,9 +4693,9 @@ def _save_updated_(obj): where_list = [ 'WHERE' ] pk_columns = obj._pk_columns_ pk_converters = obj._pk_converters_ - params_count = populate_criteria_list(where_list, pk_columns, pk_converters, params_count) + params_count = populate_criteria_list(where_list, pk_columns, repeat('EQ'), params_count) if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_converters, params_count) + populate_criteria_list(where_list, optimistic_columns, optimistic_ops, params_count) sql_ast = [ 'UPDATE', obj._table_, list(izip(update_columns, update_params)), where_list ] sql, adapter = database._ast2sql(sql_ast) obj._update_sql_cache_[query_key] = sql, adapter @@ -4713,18 +4713,18 @@ def 
_save_deleted_(obj): values.extend(obj._get_raw_pkval_()) cache = obj._session_cache_ if obj not in cache.for_update: - optimistic_columns, optimistic_converters, optimistic_values = \ + optimistic_ops, optimistic_columns, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) - else: optimistic_columns = optimistic_converters = () - query_key = (tuple(optimistic_columns), tuple(converter is not None for converter in optimistic_converters)) + else: optimistic_columns = optimistic_values = optimistic_ops = () + query_key = tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._delete_sql_cache_.get(query_key) if cached_sql is None: where_list = [ 'WHERE' ] - params_count = populate_criteria_list(where_list, obj._pk_columns_, obj._pk_converters_) + params_count = populate_criteria_list(where_list, obj._pk_columns_, repeat('EQ')) if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_converters, params_count) + populate_criteria_list(where_list, optimistic_columns, optimistic_ops, params_count) from_ast = [ 'FROM', [ None, 'TABLE', obj._table_ ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index b0e625a95..c2669064e 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -326,6 +326,8 @@ def disconnect(pool): if con is not None: con.close() class Converter(object): + EQ = 'EQ' + NE = 'NE' optimistic = True def __deepcopy__(converter, memo): return converter # Converter instances are "immutable" diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 156f42647..eb27e7adf 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1705,6 +1705,8 @@ def negate(monad): cmp_negate.update((b, a) for a, b in items_list(cmp_negate)) class CmpMonad(BoolMonad): + EQ = 'EQ' + NE = 'NE' def 
__init__(monad, op, left, right): translator = left.translator if op == '<>': op = '!=' @@ -1746,13 +1748,13 @@ def getsql(monad, subquery=None): return [ [ cmp_ops[op], [ 'ROW' ] + left_sql, [ 'ROW' ] + right_sql ] ] clauses = [] for i in xrange(1, size): - clauses.append(sqland([ [ 'EQ', left_sql[j], right_sql[j] ] for j in xrange(1, i) ] + clauses.append(sqland([ [ monad.EQ, left_sql[j], right_sql[j] ] for j in xrange(1, i) ] + [ [ cmp_ops[op[0] if i < size - 1 else op], left_sql[i], right_sql[i] ] ])) return [ sqlor(clauses) ] if op == '==': - return [ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, right_sql) ]) ] + return [ sqland([ [ monad.EQ, a, b ] for a, b in izip(left_sql, right_sql) ]) ] if op == '!=': - return [ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, right_sql) ]) ] + return [ sqlor([ [ monad.NE, a, b ] for a, b in izip(left_sql, right_sql) ]) ] assert False, op # pragma: no cover class LogicalBinOpMonad(BoolMonad): From f70a56a1435e5fece7bd20658d94bf666ef7cf60 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 17 Aug 2016 16:31:52 +0300 Subject: [PATCH 017/547] Specify converters to optimistic criteria parameters --- pony/orm/core.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e3f3c0863..62b96fcbf 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4080,12 +4080,12 @@ def _get_attrs_(entity, only=None, exclude=None, with_collections=False, with_la entity._attrnames_cache_[key] = attrs return attrs -def populate_criteria_list(criteria_list, columns, operations, params_count=0, table_alias=None): - for column, op in izip(columns, operations): +def populate_criteria_list(criteria_list, columns, converters, operations, params_count=0, table_alias=None): + for column, op, converter in izip(columns, operations, converters): if op == 'IS_NULL': criteria_list.append([ op, [ 'COLUMN', None, column ] ]) else: - criteria_list.append([ op, [ 
'COLUMN', table_alias, column ], [ 'PARAM', (params_count, None, None), None ] ]) + criteria_list.append([ op, [ 'COLUMN', table_alias, column ], [ 'PARAM', (params_count, None, None), converter ] ]) params_count += 1 return params_count @@ -4552,6 +4552,7 @@ def _attrs_with_bit_(entity, attrs, mask=-1): if get_bit(attr) & mask: yield attr def _construct_optimistic_criteria_(obj): optimistic_columns = [] + optimistic_converters = [] optimistic_values = [] optimistic_operations = [] for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): @@ -4560,13 +4561,14 @@ def _construct_optimistic_criteria_(obj): if not (attr.optimistic and converters[0].optimistic): continue dbval = obj._dbvals_[attr] optimistic_columns.extend(attr.columns) + optimistic_converters.extend(attr.converters) values = attr.get_raw_values(dbval) optimistic_values.extend(values) if dbval is None: optimistic_operations.append('IS_NULL') else: optimistic_operations.extend(converter.EQ for converter in converters) - return optimistic_operations, optimistic_columns, optimistic_values + return optimistic_operations, optimistic_columns, optimistic_converters, optimistic_values def _save_principal_objects_(obj, dependent_objects): if dependent_objects is None: dependent_objects = [] elif obj in dependent_objects: @@ -4676,10 +4678,10 @@ def _save_updated_(obj): values.extend(attr.get_raw_values(val)) cache = obj._session_cache_ if obj not in cache.for_update: - optimistic_ops, optimistic_columns, optimistic_values = \ + optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) - else: optimistic_columns = optimistic_values = optimistic_ops = () + else: optimistic_columns = optimistic_converters = optimistic_ops = () query_key = tuple(update_columns), tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._update_sql_cache_.get(query_key) @@ -4693,9 +4695,9 @@ 
def _save_updated_(obj): where_list = [ 'WHERE' ] pk_columns = obj._pk_columns_ pk_converters = obj._pk_converters_ - params_count = populate_criteria_list(where_list, pk_columns, repeat('EQ'), params_count) + params_count = populate_criteria_list(where_list, pk_columns, pk_converters, repeat('EQ'), params_count) if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_ops, params_count) + populate_criteria_list(where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count) sql_ast = [ 'UPDATE', obj._table_, list(izip(update_columns, update_params)), where_list ] sql, adapter = database._ast2sql(sql_ast) obj._update_sql_cache_[query_key] = sql, adapter @@ -4713,18 +4715,18 @@ def _save_deleted_(obj): values.extend(obj._get_raw_pkval_()) cache = obj._session_cache_ if obj not in cache.for_update: - optimistic_ops, optimistic_columns, optimistic_values = \ + optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) - else: optimistic_columns = optimistic_values = optimistic_ops = () + else: optimistic_columns = optimistic_converters = optimistic_ops = () query_key = tuple(optimistic_columns), tuple(optimistic_ops) database = obj._database_ cached_sql = obj._delete_sql_cache_.get(query_key) if cached_sql is None: where_list = [ 'WHERE' ] - params_count = populate_criteria_list(where_list, obj._pk_columns_, repeat('EQ')) + params_count = populate_criteria_list(where_list, obj._pk_columns_, obj._pk_converters_, repeat('EQ')) if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_ops, params_count) + populate_criteria_list(where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count) from_ast = [ 'FROM', [ None, 'TABLE', obj._table_ ] ] sql_ast = [ 'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) From 
b9fe819cf9b6204b808279219a84b9ce8037317a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 30 Jun 2016 18:11:17 +0300 Subject: [PATCH 018/547] val2dbval <-> dbval2val --- pony/orm/core.py | 20 +++++++++++++++----- pony/orm/dbapiprovider.py | 4 ++++ pony/orm/sqlbuilding.py | 10 ++++++++-- pony/orm/sqltranslation.py | 6 ++++-- 4 files changed, 31 insertions(+), 9 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 62b96fcbf..e811ed6ac 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2078,8 +2078,13 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): vals[i] = new_dbval new_vals = tuple(vals) cache.db_update_composite_index(obj, attrs, old_vals, new_vals) - if new_dbval is NOT_LOADED: obj._vals_.pop(attr, None) - else: obj._vals_[attr] = new_dbval + if new_dbval is NOT_LOADED: + obj._vals_.pop(attr, None) + elif attr.reverse: + obj._vals_[attr] = new_dbval + else: + assert len(attr.converters) == 1 + obj._vals_[attr] = attr.converters[0].dbval2val(new_dbval, obj) reverse = attr.reverse if not reverse: pass @@ -4357,8 +4362,10 @@ def _db_set_(obj, avdict, unpickling=False): new_vals = tuple(vals) cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) - for attr, new_dbval in iteritems(avdict): - obj._vals_[attr] = new_dbval + for attr, new_val in iteritems(avdict): + converter = attr.converters[0] + new_val = converter.dbval2val(new_val, obj) + obj._vals_[attr] = new_val def _delete_(obj, undo_funcs=None): status = obj._status_ if status in del_statuses: return @@ -4604,7 +4611,10 @@ def _update_dbvals_(obj, after_create): elif after_create and val is None: obj._rbits_ &= ~bits[attr] del vals[attr] - else: dbvals[attr] = val + else: + # TODO this conversion should be unnecessary + converter = attr.converters[0] + dbvals[attr] = converter.val2dbval(val, obj) def _save_created_(obj): auto_pk = (obj._pkval_ is None) attrs = [] diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index c2669064e..f266b1778 
100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -348,6 +348,10 @@ def py2sql(converter, val): return val def sql2py(converter, val): return val + def val2dbval(self, val, obj=None): + return val + def dbval2val(self, dbval, obj=None): + return dbval def get_sql_type(converter, attr=None): if attr is not None and attr.sql_type is not None: return attr.sql_type diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 0b61a8641..eec272bc5 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -13,12 +13,18 @@ class AstError(Exception): pass class Param(object): - __slots__ = 'style', 'id', 'paramkey', 'py2sql' + __slots__ = 'style', 'id', 'paramkey', 'converter' def __init__(param, paramstyle, id, paramkey, converter=None): param.style = paramstyle param.id = id param.paramkey = paramkey - param.py2sql = converter.py2sql if converter else (lambda val: val) + param.converter = converter + def py2sql(param, val): + converter = param.converter + if converter is not None: + val = converter.val2dbval(val) + val = converter.py2sql(val) + return val def __unicode__(param): paramstyle = param.style if paramstyle == 'qmark': return u'?' 
diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index eb27e7adf..9564ae240 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -346,9 +346,11 @@ def func(values, constructor=expr_type._get_by_raw_pkval_): offset = next_offset else: converter = provider.get_converter_by_py_type(expr_type) - def func(value, sql2py=converter.sql2py): + def func(value, converter=converter): if value is None: return None - return sql2py(value) + value = converter.sql2py(value) + value = converter.dbval2val(value) + return value row_layout.append((func, offset, ast2src(m.node))) m.orderby_columns = (offset+1,) offset += 1 From 15277ea3c625e7fc44a9cd5d82ffee29b6aea190 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Fri, 19 Aug 2016 19:59:59 +0300 Subject: [PATCH 019/547] Updating setup.py links and removing Python 2.6 --- setup.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/setup.py b/setup.py index e8f6bdbfe..95020b54a 100644 --- a/setup.py +++ b/setup.py @@ -43,12 +43,12 @@ Pony ORM Links: ================= -- Main site: http://ponyorm.com -- Documentation: http://doc.ponyorm.com +- Main site: https://ponyorm.com +- Documentation: https://docs.ponyorm.com - GitHub: https://github.com/ponyorm/pony - Mailing list: http://ponyorm-list.ponyorm.com - ER Diagram Editor: https://editor.ponyorm.com -- Blog: http://blog.ponyorm.com +- Blog: https://blog.ponyorm.com """ classifiers = [ @@ -62,7 +62,6 @@ 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.3', @@ -74,7 +73,7 @@ author = "Alexander Kozlovsky, Alexey Malashkevich" author_email = "team@ponyorm.com" -url = "http://ponyorm.com" +url = "https://ponyorm.com" lic = "AGPL, Commercial, Free for educational and non-commercial use" packages 
= [ @@ -92,8 +91,8 @@ if __name__ == "__main__": pv = sys.version_info[:2] - if pv not in ((2, 6), (2, 7), (3, 3), (3, 4), (3, 5)): - s = "Sorry, but %s %s requires Python of one of the following versions: 2.6, 2.7, 3.3, 3.4 and 3.5." \ + if pv not in ((2, 7), (3, 3), (3, 4), (3, 5)): + s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3, 3.4 and 3.5." \ " You have version %s" print(s % (name, version, sys.version.split(' ', 1)[0])) sys.exit(1) From e0761611f93dab084b9964ca0d10b5249b525688 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 19 Aug 2016 20:46:30 +0300 Subject: [PATCH 020/547] Typo --- pony/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/utils.py b/pony/utils.py index d4fbb3707..a6475e790 100644 --- a/pony/utils.py +++ b/pony/utils.py @@ -427,7 +427,7 @@ def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None): strings[i] = s.decode(source_encoding, 'replace').replace(u'\ufffd', '?') result = sep.join(strings) if dest_encoding is None: return result - return result.encode(dest_encoding, replace) + return result.encode(dest_encoding, 'replace') def make_offsets(s): offsets = [ 0 ] From bde5e1ffa069e481aa8e13a37d2223175ebb2092 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 30 Jun 2016 21:09:48 +0300 Subject: [PATCH 021/547] TrackedDict & TrackedList classes added --- pony/orm/core.py | 21 +++++++++ pony/orm/ormtypes.py | 69 +++++++++++++++++++++++++++- pony/orm/tests/test_tracked_value.py | 56 ++++++++++++++++++++++ 3 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 pony/orm/tests/test_tracked_value.py diff --git a/pony/orm/core.py b/pony/orm/core.py index e811ed6ac..ee23e7528 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4311,6 +4311,27 @@ def load(obj, *attrs): objects = entity._fetch_objects(cursor, attr_offsets) if obj not in objects: throw(UnrepeatableReadError, 'Phantom object %s disappeared' % safe_repr(obj)) + def 
_attr_changed_(obj, attr): + cache = obj._session_cache_ + if not cache.is_alive: throw( + DatabaseSessionIsOver, + 'Cannot assign new value to attribute %s.%s: the database session' + ' is over' % (safe_repr(obj), attr.name)) + if obj._status_ in del_statuses: + throw_object_was_deleted(obj) + status = obj._status_ + wbits = obj._wbits_ + bit = obj._bits_[attr] + objects_to_save = cache.objects_to_save + if wbits is not None and bit: + obj._wbits_ |= bit + if status != 'modified': + assert status in ('loaded', 'inserted', 'updated') + assert obj._save_pos_ is None + obj._status_ = 'modified' + obj._save_pos_ = len(objects_to_save) + objects_to_save.append(obj) + cache.modified = True def _db_set_(obj, avdict, unpickling=False): assert obj._status_ not in created_or_deleted_statuses cache = obj._session_cache_ diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index ae18c9412..0b49d3120 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -1,9 +1,10 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types -import types +import types, weakref from decimal import Decimal from datetime import date, time, datetime, timedelta +from functools import wraps, WRAPPER_ASSIGNMENTS from uuid import UUID from pony.utils import throw, parse_expr @@ -213,3 +214,69 @@ def are_comparable_types(t1, t2, op='=='): return False if t1 is t2 and t1 in comparable_types: return True return (t1, t2) in coercions + +class TrackedValue(object): + def __init__(self, obj, attr): + self.obj_ref = weakref.ref(obj) + self.attr = attr + @classmethod + def make(cls, obj, attr, value): + if isinstance(value, dict): + return TrackedDict(obj, attr, value) + if isinstance(value, list): + return TrackedList(obj, attr, value) + return value + def _changed_(self): + obj = self.obj_ref() + if obj is not None: + obj._attr_changed_(self.attr) + def get_untracked(self): + assert False, 
'Abstract method' # pragma: no cover + +def tracked_method(func): + @wraps(func, assigned=('__name__', '__doc__') if PY2 else WRAPPER_ASSIGNMENTS) + def new_func(self, *args, **kw): + result = func(self, *args, **kw) + self._changed_() + return result + return new_func + +class TrackedDict(TrackedValue, dict): + def __init__(self, obj, attr, value): + TrackedValue.__init__(self, obj, attr) + dict.__init__(self, ((key, self.make(obj, attr, val)) + for key, val in value.items())) + def __reduce__(self): + return dict, (dict(self),) + __setitem__ = tracked_method(dict.__setitem__) + __delitem__ = tracked_method(dict.__delitem__) + update = tracked_method(dict.update) + setdefault = tracked_method(dict.setdefault) + pop = tracked_method(dict.pop) + popitem = tracked_method(dict.popitem) + clear = tracked_method(dict.clear) + def get_untracked(self): + return {key: val.get_untracked() if isinstance(val, TrackedValue) else val + for key, val in self.items()} + +class TrackedList(TrackedValue, list): + def __init__(self, obj, attr, value): + TrackedValue.__init__(self, obj, attr) + list.__init__(self, (self.make(obj, attr, val) for val in value)) + def __reduce__(self): + return list, (list(self),) + __setitem__ = tracked_method(list.__setitem__) + __delitem__ = tracked_method(list.__delitem__) + extend = tracked_method(list.extend) + append = tracked_method(list.append) + pop = tracked_method(list.pop) + remove = tracked_method(list.remove) + insert = tracked_method(list.insert) + reverse = tracked_method(list.reverse) + sort = tracked_method(list.sort) + if PY2: + __setslice__ = tracked_method(list.__setslice__) + else: + clear = tracked_method(list.clear) + def get_untracked(self): + return [val.get_untracked() if isinstance(val, TrackedValue) else val for val in self] diff --git a/pony/orm/tests/test_tracked_value.py b/pony/orm/tests/test_tracked_value.py new file mode 100644 index 000000000..f24ce5a7c --- /dev/null +++ b/pony/orm/tests/test_tracked_value.py @@ -0,0 
+1,56 @@ +import unittest + +from pony.orm.ormtypes import TrackedList, TrackedDict, TrackedValue + +class Object(object): + def __init__(self): + self.on_attr_changed = None + def _attr_changed_(self, attr): + if self.on_attr_changed is not None: + self.on_attr_changed(attr) + + +class Attr(object): + pass + + +class TestTrackedValue(unittest.TestCase): + + def test_make(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + self.assertEqual(type(tracked_value), TrackedDict) + self.assertEqual(type(tracked_value['items']), TrackedList) + + def test_dict_setitem(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + log = [] + obj.on_attr_changed = lambda x: log.append(x) + tracked_value['items'] = [1, 2, 3] + self.assertEqual(log, [attr]) + + def test_list_append(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + log = [] + obj.on_attr_changed = lambda x: log.append(x) + tracked_value['items'].append('four') + self.assertEqual(log, [attr]) + + def test_list_setslice(self): + obj = Object() + attr = Attr() + value = {'items': ['one', 'two', 'three']} + tracked_value = TrackedValue.make(obj, attr, value) + log = [] + obj.on_attr_changed = lambda x: log.append(x) + tracked_value['items'][1:2] = ['a', 'b', 'c'] + self.assertEqual(log, [attr]) + self.assertEqual(tracked_value['items'], ['one', 'a', 'b', 'c', 'three']) From 7f928a0342250ab3d936cd5df4ac4899c9097517 Mon Sep 17 00:00:00 2001 From: Vitalii Date: Wed, 23 Mar 2016 19:24:23 +0300 Subject: [PATCH 022/547] JSON support added --- pony/fixtures.py | 340 ++++++++++++++ pony/orm/core.py | 4 +- pony/orm/dbapiprovider.py | 23 +- pony/orm/dbproviders/mysql.py | 97 +++- pony/orm/dbproviders/oracle.py | 53 +++ pony/orm/dbproviders/postgres.py | 124 +++++- 
pony/orm/dbproviders/sqlite.py | 164 ++++++- pony/orm/ormtypes.py | 16 +- pony/orm/sqlbuilding.py | 14 + pony/orm/sqltranslation.py | 272 +++++++++++- pony/orm/tests/test_json/__init__.py | 47 ++ pony/orm/tests/test_json/_postgres.py | 127 ++++++ pony/orm/tests/test_json/test.py | 609 ++++++++++++++++++++++++++ pony/utils/__init__.py | 4 + pony/utils/properties.py | 40 ++ pony/{ => utils}/utils.py | 3 +- 16 files changed, 1924 insertions(+), 13 deletions(-) create mode 100644 pony/fixtures.py create mode 100644 pony/orm/tests/test_json/__init__.py create mode 100644 pony/orm/tests/test_json/_postgres.py create mode 100644 pony/orm/tests/test_json/test.py create mode 100644 pony/utils/__init__.py create mode 100644 pony/utils/properties.py rename pony/{ => utils}/utils.py (97%) diff --git a/pony/fixtures.py b/pony/fixtures.py new file mode 100644 index 000000000..f61a8798d --- /dev/null +++ b/pony/fixtures.py @@ -0,0 +1,340 @@ +import os +import logging + +from pony.py23compat import PY2 +from ponytest import with_cli_args, pony_fixtures + +from functools import wraps +import click +from contextlib import contextmanager, closing + +from pony.utils import cached_property, class_cached_property + +from pony.orm.dbproviders.mysql import mysql_module +from pony.utils import cached_property, class_property + +if not PY2: + from contextlib import contextmanager +else: + from contextlib2 import contextmanager + +from pony.orm import db_session, Database, rollback + + +class DBContext(object): + + class_scoped = True + + def __init__(self, test_cls): + test_cls.db_fixture = self + test_cls.db = class_property(lambda cls: self.db) + test_cls.db_provider = class_property(lambda cls: self.provider) + self.test_cls = test_cls + + @class_property + def fixture_name(cls): + return cls.provider + + def init_db(self): + raise NotImplementedError + + @cached_property + def db(self): + raise NotImplementedError + + def __enter__(self): + self.init_db() + self.test_cls.make_entities() 
+ self.db.generate_mapping(check_tables=True, create_tables=True) + + def __exit__(self, *exc_info): + self.db.provider.disconnect() + + + @classmethod + @with_cli_args + @click.option('--db', '-d', 'database', multiple=True) + @click.option('--exclude-db', '-e', multiple=True) + def invoke(cls, database, exclude_db): + fixture = [ + MySqlContext, OracleContext, SqliteContext, PostgresContext, + SqlServerContext, + ] + all_db = [ctx.provider for ctx in fixture] + for db in database: + if db == 'all': + continue + assert db in all_db, ( + "Unknown provider: %s. Use one of %s." % (db, ', '.join(all_db)) + ) + if 'all' in database: + database = all_db + elif exclude_db: + database = set(all_db) - set(exclude_db) + elif not database: + database = ['sqlite'] + for Ctx in fixture: + if Ctx.provider in database: + yield Ctx + + db_name = 'testdb' + + +pony_fixtures.appendleft(DBContext.invoke) + + +class MySqlContext(DBContext): + provider = 'mysql' + + + def drop_db(self, cursor): + cursor.execute('use sys') + cursor.execute('drop database %s' % self.db_name) + + + def init_db(self): + with closing(mysql_module.connect(**self.CONN).cursor()) as c: + try: + self.drop_db(c) + except mysql_module.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + c.execute('create database %s' % self.db_name) + c.execute('use %s' % self.db_name) + + CONN = { + 'host': "localhost", + 'user': "ponytest", + 'passwd': "ponytest", + } + + @cached_property + def db(self): + CONN = dict(self.CONN, db=self.db_name) + return Database('mysql', **CONN) + + +class SqlServerContext(DBContext): + + provider = 'sqlserver' + + def get_conn_string(self, db=None): + s = ( + 'DSN=MSSQLdb;' + 'SERVER=mssql;' + 'UID=sa;' + 'PWD=pass;' + ) + if db: + s += 'DATABASE=%s' % db + return s + + @cached_property + def db(self): + CONN = self.get_conn_string(self.db_name) + return Database('mssqlserver', CONN) + + def init_db(self): + import pyodbc + cursor = pyodbc.connect(self.get_conn_string(), 
autocommit=True).cursor() + with closing(cursor) as c: + try: + self.drop_db(c) + except pyodbc.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + c.execute('create database %s' % self.db_name) + c.execute('use %s' % self.db_name) + + def drop_db(self, cursor): + cursor.execute('use master') + cursor.execute('drop database %s' % self.db_name) + + +class SqliteContext(DBContext): + provider = 'sqlite' + + def init_db(self): + try: + os.remove(self.db_path) + except OSError as exc: + print('Failed to drop db: %s' % exc) + + + @cached_property + def db_path(self): + p = os.path.dirname(__file__) + p = os.path.join(p, self.db_name) + return os.path.abspath(p) + + @cached_property + def db(self): + return Database('sqlite', self.db_path, create_db=True) + + +class PostgresContext(DBContext): + provider = 'postgresql' + + def get_conn_dict(self, no_db=False): + d = dict( + user='ponytest', password='ponytest', + host='localhost' + ) + if not no_db: + d.update(database=self.db_name) + return d + + def init_db(self): + import psycopg2 + conn = psycopg2.connect( + **self.get_conn_dict(no_db=True) + ) + conn.set_isolation_level(0) + with closing(conn.cursor()) as cursor: + try: + self.drop_db(cursor) + except psycopg2.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + cursor.execute('create database %s' % self.db_name) + + def drop_db(self, cursor): + cursor.execute('drop database %s' % self.db_name) + + + @cached_property + def db(self): + return Database('postgres', **self.get_conn_dict()) + + +class OracleContext(DBContext): + provider = 'oracle' + + def __enter__(self): + os.environ.update(dict( + ORACLE_BASE='/u01/app/oracle', + ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1', + ORACLE_OWNR='oracle', + ORACLE_SID='orcl', + )) + return super(OracleContext, self).__enter__() + + def init_db(self): + import cx_Oracle + with closing(self.connect_sys()) as conn: + with closing(conn.cursor()) as cursor: + try: + self._destroy_test_user(cursor) + 
self._drop_tablespace(cursor) + except cx_Oracle.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + cursor.execute( + """CREATE TABLESPACE %(tblspace)s + DATAFILE '%(datafile)s' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s + """ % self.parameters) + cursor.execute( + """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s + TEMPFILE '%(datafile_tmp)s' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s + """ % self.parameters) + self._create_test_user(cursor) + + + def _drop_tablespace(self, cursor): + cursor.execute( + 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' + % self.parameters) + cursor.execute( + 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' + % self.parameters) + + + parameters = { + 'tblspace': 'test_tblspace', + 'tblspace_temp': 'test_tblspace_temp', + 'datafile': 'test_datafile.dbf', + 'datafile_tmp': 'test_datafile_tmp.dbf', + 'user': 'ponytest', + 'password': 'ponytest', + 'maxsize': '100M', + 'maxsize_tmp': '100M', + } + + def connect_sys(self): + import cx_Oracle + return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA) + + def connect_test(self): + import cx_Oracle + return cx_Oracle.connect('test_user/test_password@localhost/ORCL') + + + @cached_property + def db(self): + return Database('oracle', 'test_user/test_password@localhost/ORCL') + + def _create_test_user(self, cursor): + cursor.execute( + """CREATE USER %(user)s + IDENTIFIED BY %(password)s + DEFAULT TABLESPACE %(tblspace)s + TEMPORARY TABLESPACE %(tblspace_temp)s + QUOTA UNLIMITED ON %(tblspace)s + """ % self.parameters + ) + cursor.execute( + """GRANT CREATE SESSION, + CREATE TABLE, + CREATE SEQUENCE, + CREATE PROCEDURE, + CREATE TRIGGER + TO %(user)s + """ % self.parameters + ) + + def _destroy_test_user(self, cursor): + cursor.execute(''' + DROP USER %(user)s CASCADE + ''' % self.parameters) + + +@contextmanager +def logging_context(test): + level = 
logging.getLogger().level + from pony.orm.core import debug, sql_debug + logging.getLogger().setLevel(logging.INFO) + sql_debug(True) + yield + logging.getLogger().setLevel(level) + sql_debug(debug) + + +@with_cli_args +@click.option('--log', is_flag=True) +def use_logging(log): + if log: + yield logging_context + +pony_fixtures.appendleft(use_logging) + + +class DBSession(object): + + def __init__(self, test): + self.test = test + + @property + def in_db_session(self): + ret = getattr(self.test, 'in_db_session', True) + method = getattr(self.test, self.test._testMethodName) + return getattr(method, 'in_db_session', ret) + + def __enter__(self): + rollback() + if self.in_db_session: + db_session.__enter__() + + def __exit__(self, *exc_info): + rollback() + if self.in_db_session: + db_session.__exit__() + +pony_fixtures.appendleft([DBSession]) diff --git a/pony/orm/core.py b/pony/orm/core.py index ee23e7528..5d0ae7d3c 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -19,7 +19,7 @@ import pony from pony import options from pony.orm.decompiling import decompile -from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of +from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of, Json from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, @@ -53,7 +53,7 @@ composite_key composite_index flush commit rollback db_session with_transaction - LongStr LongUnicode + LongStr LongUnicode Json select left_join get exists delete diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index f266b1778..d52323d0b 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, basestring, unicode, buffer, int_types -import 
os, re +import os, re, json from decimal import Decimal, InvalidOperation from datetime import datetime, date, time, timedelta from uuid import uuid4, UUID @@ -9,7 +9,7 @@ import pony from pony.utils import is_utf8, decorator, throw, localbase, deprecated from pony.converting import str2date, str2time, str2datetime, str2timedelta -from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType +from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, Json class DBException(Exception): def __init__(exc, original_exc, *args): @@ -732,3 +732,22 @@ def py2sql(converter, val): sql2py = validate def sql_type(converter): return "UUID" + +class JsonConverter(Converter): + json_kwargs = {} + class JsonEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, Json): + return obj.wrapped + return json.JSONEncoder.default(self, obj) + def val2dbval(self, val, obj=None): + return json.dumps(val, cls=self.JsonEncoder, **self.json_kwargs) + def dbval2val(self, dbval, obj=None): + if isinstance(dbval, (int, bool, float, type(None))): + return dbval + val = json.loads(dbval) + if obj is None: + return val + return TrackedValue.make(obj, self.attr, val) + def sql_type(self): + return "JSON" diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 4e9165a4d..5dfb6c7cc 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -27,7 +27,7 @@ mysql_converters.encoders[timedelta] = lambda val: mysql_converters.escape_str(timedelta2str(val)) mysql_module_name = 'pymysql' -from pony.orm import core, dbschema, dbapiprovider +from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator @@ -46,6 +46,77 @@ class MySQLSchema(dbschema.DBSchema): class MySQLTranslator(SQLTranslator): dialect = 'MySQL' + + + class 
CmpMonad(sqltranslation.CmpMonad): + + def make_json_cast_if_needed(monad, left_sql, right_sql): + translator = monad.left.translator + if monad.op not in ('==', '!='): + return sqltranslation.CmpMonad.make_json_cast_if_needed( + monad, left_sql, right_sql + ) + def need_cast(monad): + if isinstance(monad, sqltranslation.ParamMonad): + return True + return not isinstance(monad, sqltranslation.JsonMixin) + + if need_cast(monad.left): + sql = left_sql[0] + expr = translator.CastToJsonExprMonad( + translator, sql, target_monad=monad.left + ) + return expr.getsql(), right_sql + if need_cast(monad.right): + sql = right_sql[0] + expr = translator.CastToJsonExprMonad( + translator, sql, target_monad=monad.right + ) + return left_sql, expr.getsql() + return left_sql, right_sql + + + class CastFromJsonExprMonad(sqltranslation.CastFromJsonExprMonad): + + @classmethod + def dispatch_type(cls, typ): + if issubclass(typ, int): + return 'signed' + if issubclass(typ, float): + raise sqltranslation.AbortCast + return sqltranslation.CastFromJsonExprMonad.dispatch_type(typ) + + class JsonContainsExprMonad(sqltranslation.JsonContainsExprMonad): + + def __init__(monad, json_monad, item): + if not isinstance(item, sqltranslation.StringConstMonad): + raise NotImplementedError + sqltranslation.JsonContainsExprMonad.__init__( + monad, json_monad, item + ) + + def _dict_contains(monad): + path_sql = monad.json_monad._get_path_sql( + getattr(monad.json_monad, 'path', ()) + ) + path_sql.append(monad.item.value) + return ['JSON_CONTAINS_PATH', monad.attr_sql, path_sql] + + def _list_contains(monad): + translator = monad.translator + path_sql = monad.json_monad._get_path_sql( + getattr(monad.json_monad, 'path', ()) + ) + item = translator.ConstMonad.new(translator, '["%s"]' % monad.item.value) + item_sql, = item.getsql() + return ['JSON_CONTAINS', monad.attr_sql, path_sql, item_sql] + + def getsql(monad): + return [ + ['OR', monad._dict_contains(), monad._list_contains()] + ] + + class 
MySQLBuilder(SQLBuilder): dialect = 'MySQL' def CONCAT(builder, *args): @@ -87,6 +158,22 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return 'DATE_SUB(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' + def JSON_GETPATH(builder, expr, key): + return 'json_extract(', builder(expr), ', ', builder(key), ')' + def JSON_SUBTRACT_PATH(builder, expr, key): + return 'json_remove(', builder(expr), ', ', builder(key), ')' + def JSON_ARRAY_LENGTH(builder, value): + return 'json_length(', builder(value), ')' + def AS_JSON(builder, target): + return 'CAST(', builder(target), ' AS JSON)' + def EQ_JSON(builder, left, right): + return '(', builder(left), '=', builder.AS_JSON(right), ')' + def NE_JSON(builder, left, right): + return '(', builder(left), '!=', builder.AS_JSON(right), ')' + def JSON_CONTAINS(builder, expr, path, key): + return 'json_contains(', builder(expr), ', ', builder(key), ', ', builder(path), ')' + def JSON_CONTAINS_PATH(builder, expr, path): + return 'json_contains_path(', builder(expr), ", 'one', ", builder(path), ')' class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): @@ -122,6 +209,13 @@ class MySQLUuidConverter(dbapiprovider.UuidConverter): def sql_type(converter): return 'BINARY(16)' +class MySQLJsonConverter(dbapiprovider.JsonConverter): + EQ = 'EQ_JSON' + def init(self, kwargs): + if self.provider.server_version < (5, 7, 8): + version = '.'.join(imap(str, self.provider.server_version)) + raise NotImplementedError("MySQL %s has no JSON support" % version) + class MySQLProvider(DBAPIProvider): dialect = 'MySQL' paramstyle = 'format' @@ -154,6 +248,7 @@ class MySQLProvider(DBAPIProvider): (timedelta, MySQLTimedeltaConverter), (UUID, MySQLUuidConverter), (buffer, MySQLBlobConverter), + (ormtypes.Json, MySQLJsonConverter), ] def normalize_name(provider, name): diff --git a/pony/orm/dbproviders/oracle.py 
b/pony/orm/dbproviders/oracle.py index 7b654475a..a13d474eb 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -13,6 +13,7 @@ from pony.orm import core, sqlbuilding, dbapiprovider, sqltranslation from pony.orm.core import log_orm, log_sql, DatabaseError, TranslationError from pony.orm.dbschema import DBSchema, DBObject, Table, Column +from pony.orm.ormtypes import Json from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple from pony.utils import throw from pony.converting import timedelta2str @@ -124,6 +125,36 @@ def get_normalized_type_of(translator, value): if value == '': return NoneType return sqltranslation.SQLTranslator.get_normalized_type_of(value) + class JsonItemMonad(sqltranslation.JsonItemMonad): + def nonzero(monad): + raise NotImplementedError + + class JsonContainsExprMonad(sqltranslation.JsonContainsExprMonad): + + def __init__(monad, json_monad, item): + if not isinstance(item, sqltranslation.StringConstMonad): + raise NotImplementedError + sqltranslation.JsonContainsExprMonad.__init__( + monad, json_monad, item + ) + + def _dict_contains(monad): + path_sql = monad.json_monad._get_path_sql( + getattr(monad.json_monad, 'path', ()) + ) + path_sql.append(monad.item.value) + return ['JSON_CONTAINS_PATH', monad.attr_sql, path_sql] + + def _list_contains(monad): + path_sql = monad.json_monad._get_path_sql( + getattr(monad.json_monad, 'path', ()) + ) + return ['JSON_LIST_CONTAINS', monad.attr_sql, path_sql, monad.item.value] + + def getsql(monad): + return [ ['OR', monad._dict_contains(), monad._list_contains()] ] + + class OraBuilder(sqlbuilding.SQLBuilder): dialect = 'Oracle' def INSERT(builder, table_name, columns, values, returning=None): @@ -225,6 +256,19 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def 
JSON_GETPATH(builder, expr, key): + query = 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' + return 'REGEXP_REPLACE(', query, ", '(^\[|\]$)', '')" + def JSON_EXISTS(builder, expr, key): + return 'JSON_EXISTS(', builder(expr), ', ', builder(key), ')' + def JSON_CONTAINS_PATH(builder, expr, path): + return builder.JSON_EXISTS(expr, path) + def JSON_LIST_CONTAINS(builder, expr, path, key): + query = 'JSON_QUERY(', builder(expr), ', ', builder(path), ')' + return 'REGEXP_LIKE(', query, ', \'', search_in_json_list_regexp(key), '\')' + +def search_in_json_list_regexp(what): + return r'^\[(.+, ?)?"%s"(, ?.+)?\]$' % what class OraBoolConverter(dbapiprovider.BoolConverter): if not PY2: @@ -317,6 +361,14 @@ class OraUuidConverter(dbapiprovider.UuidConverter): def sql_type(converter): return 'RAW(16)' +class OraJsonConverter(dbapiprovider.JsonConverter): + optimistic = False + def sql2py(converter, dbval): + if hasattr(dbval, 'read'): dbval = dbval.read() + return dbapiprovider.JsonConverter.sql2py(converter, dbval) + def sql_type(converter): + return 'CLOB' + class OraProvider(DBAPIProvider): dialect = 'Oracle' paramstyle = 'named' @@ -346,6 +398,7 @@ class OraProvider(DBAPIProvider): (timedelta, OraTimedeltaConverter), (UUID, OraUuidConverter), (buffer, OraBlobConverter), + (Json, OraJsonConverter), ] @wrap_dbapi_exceptions diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index c2bfba44e..600484040 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from pony.py23compat import PY2, basestring, unicode, buffer, int_types +import json from decimal import Decimal from datetime import datetime, date, time, timedelta from uuid import UUID @@ -16,7 +17,10 @@ import psycopg2.extras psycopg2.extras.register_uuid() -from pony.orm import core, dbschema, dbapiprovider +psycopg2.extras.register_default_json(loads=lambda x: x) 
+psycopg2.extras.register_default_jsonb(loads=lambda x: x) + +from pony.orm import core, dbschema, dbapiprovider, sqltranslation, ormtypes from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator @@ -35,6 +39,78 @@ class PGSchema(dbschema.DBSchema): class PGTranslator(SQLTranslator): dialect = 'PostgreSQL' + class JsonItemMonad(sqltranslation.JsonItemMonad): + allow_get_by_key_syntax = True + + def nonzero(monad): + translator = monad.translator + empty_str = translator.StringExprMonad( + translator, str, ['RAWSQL', '\'""\'::jsonb'] + ) + str_not_empty = translator.CmpMonad( + '!=', monad, empty_str + ) + is_true = translator.CastFromJsonExprMonad( + bool, translator, monad.getsql()[0] + ) + sql = ['AND'] + sql.extend(str_not_empty.getsql()) + sql.extend(is_true.getsql()) + return translator.BoolExprMonad(translator, sql) + + class CmpMonad(sqltranslation.CmpMonad): + + def make_json_cast_if_needed(monad, left_sql, right_sql): + translator = monad.left.translator + if monad.op not in ('==', '!='): + return sqltranslation.CmpMonad.make_json_cast_if_needed( + monad, left_sql, right_sql + ) + if isinstance(monad.left, sqltranslation.NumericMixin): + sql = left_sql[0] + expr = translator.CastToJsonExprMonad( + translator, sql, target_monad=monad.left + ) + return expr.getsql(), right_sql + if isinstance(monad.right, sqltranslation.NumericMixin): + sql = right_sql[0] + expr = translator.CastToJsonExprMonad( + translator, sql, target_monad=monad.right + ) + return left_sql, expr.getsql() + return left_sql, right_sql + + + class CastFromJsonExprMonad(sqltranslation.CastFromJsonExprMonad): + + @classmethod + def dispatch_type(cls, typ): + sql_type = sqltranslation.CastFromJsonExprMonad.dispatch_type(typ) + if not issubclass(typ, (int, float, bool)): + return sql_type + return 'text::%s' % sql_type + + + class CastToJsonExprMonad(sqltranslation.CastToJsonExprMonad): + + 
cast_to = 'jsonb' + + def getsql(monad): + if isinstance(monad.target_monad, sqltranslation.NumericConstMonad): + monad.sql = ['SINGLE_QUOTES', monad.sql] + return sqltranslation.CastToJsonExprMonad.getsql(monad) + + class JsonContainsExprMonad(sqltranslation.JsonContainsExprMonad): + def getsql(monad): + json_monad = monad.json_monad + path = getattr(json_monad, 'path', ()) + path_sql = json_monad._get_path_sql(path) if path else None + item_sql, = monad.item.getsql() + return [ + ['JSON_CONTAINS', monad.attr_sql, path_sql, item_sql] + ] + + class PGValue(Value): __slots__ = [] def __unicode__(self): @@ -73,6 +149,47 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def JSON_PATH(builder, *items): + ret = ["'{"] + for i, item in enumerate(items): + if i: ret.append(', ') + if isinstance(item, basestring): + item = '"', item, '"' + ret.append(item) + ret.append("}'") + return ret + def JSON_GET(builder, expr, key): + val = builder.VALUE(key) + return '(', builder(expr), "->", val, ')' + def JSON_GETPATH(builder, expr, key): + return '(', builder(expr), "#>", builder(key), ')' + def JSON_CONTAINS(builder, expr, path, key): + if path: + json_sql = builder.JSON_GETPATH(expr, path) + else: + json_sql = builder(expr) + return json_sql, " ? 
", builder(key) + def JSON_CONTAINS_JSON(builder, sub_value, value): + return builder(sub_value), " <@ ", builder(value) + def JSON_IS_CONTAINED(builder, value, contained_in): + raise NotImplementedError('Not needed') + def JSON_HAS_ANY(builder, array, value): + raise NotImplementedError + def JSON_HAS_ALL(builder, array, value): + raise NotImplementedError + def JSON_SUBTRACT_VALUE(builder, expr, key): + val = builder.VALUE(key) + return '(', builder(expr), " - ", val, ')' + def JSON_SUBTRACT_PATH(builder, value, key): + return '(', builder(value), " #- ", builder(key), ')' + def JSON_ARRAY_LENGTH(builder, value): + return 'jsonb_array_length(', builder(value), ')' + def _as_json(builder, target): + return '(', builder(target), ')::jsonb' + def CAST(builder, expr, type): + return '(', builder(expr), ')::', type + def SINGLE_QUOTES(builder, expr): + return "'", builder(expr), "'" class PGStrConverter(dbapiprovider.StrConverter): if PY2: @@ -104,6 +221,10 @@ class PGUuidConverter(dbapiprovider.UuidConverter): def py2sql(converter, val): return val +class PGJsonConverter(dbapiprovider.JsonConverter): + def sql_type(self): + return "JSONB" + class PGPool(Pool): def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) @@ -244,6 +365,7 @@ def drop_table(provider, connection, table_name): (timedelta, PGTimedeltaConverter), (UUID, PGUuidConverter), (buffer, PGBlobConverter), + (ormtypes.Json, PGJsonConverter), ] provider_cls = PGProvider diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 126487220..8257d7eb6 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -2,6 +2,7 @@ from pony.py23compat import PY2, imap, basestring, buffer, int_types, unicode import os.path +import json import sqlite3 as sqlite from decimal import Decimal from datetime import datetime, date, time, timedelta @@ -10,13 +11,19 @@ from threading import Lock from uuid import UUID from binascii import hexlify +from 
functools import wraps -from pony.orm import core, dbschema, sqltranslation, dbapiprovider +from pony.orm import core, dbschema, sqltranslation, dbapiprovider, ormtypes from pony.orm.core import log_orm from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, throw +from contextlib import contextmanager + +class SqliteExtensionUnavailable(Exception): + pass + NoneType = type(None) class SQLiteForeignKey(dbschema.ForeignKey): @@ -47,8 +54,40 @@ class SQLiteTranslator(sqltranslation.SQLTranslator): StringMixin_UPPER = make_overriden_string_func('PY_UPPER') StringMixin_LOWER = make_overriden_string_func('PY_LOWER') + class CmpMonad(sqltranslation.CmpMonad): + def __init__(monad, op, left, right): + translator = left.translator + sqltranslation.CmpMonad.__init__(monad, op, left, right) + if not isinstance(left, translator.JsonMixin): + return + if op in ('==', '!='): + if isinstance(right, sqltranslation.AttrMonad) : + left.quote_strings = False + + class JsonItemMonad(sqltranslation.JsonItemMonad): + quote_strings = True + def getsql(monad): + sql, = sqltranslation.JsonItemMonad.getsql(monad) + if monad.quote_strings: + sql[0] = 'JSON_GETPATH__QUOTE_STRINGS' + return [sql] + def nonzero(monad): + translator = monad.translator + if translator.database.provider.json1_available: + monad.quote_strings = False + return monad + sql = ['PY_JSON_NONZERO'] + expr_sql = monad.attr_monad.getsql()[0] + path_sql = monad._get_path_sql(monad.path) + sql.extend([expr_sql, path_sql]) + return translator.BoolExprMonad(translator, sql) + + class SQLiteBuilder(SQLBuilder): dialect = 'SQLite' + def __init__(builder, provider, ast): + builder.json1_available = provider.json1_available + SQLBuilder.__init__(builder, provider, ast) def SELECT_FOR_UPDATE(builder, nowait, *sections): assert not builder.indent and not nowait 
return builder.SELECT(*sections) @@ -124,6 +163,42 @@ def RANDOM(builder): PY_UPPER = make_unary_func('py_upper') PY_LOWER = make_unary_func('py_lower') + def JSON_PATH(builder, *items): + if builder.json1_available: + return SQLBuilder.JSON_PATH(builder, *items) + return "'", json.dumps(items), "'" + def JSON_GETPATH(builder, expr, key): + if not builder.json1_available: + return 'py_json_extract(', builder(expr), ', ', builder(key), ', 0)' + return 'json_extract(', builder(expr), ', ', builder(key), ')' + def JSON_GETPATH__QUOTE_STRINGS(builder, expr, key): + if not builder.json1_available: + return 'py_json_extract(', builder(expr), ', ', builder(key), ', 1)' + ret = 'json_extract(', builder(expr), ', null, ', builder(key), ')' + return 'unwrap_extract_json(', ret, ')' + def JSON_SUBTRACT_PATH(builder, expr, key): + if not builder.json1_available: + raise SqliteExtensionUnavailable('json1') + return 'json_remove(', builder(expr), ', ', builder(key), ')' + def JSON_ARRAY_LENGTH(builder, value): + if not builder.json1_available: + raise SqliteExtensionUnavailable('json1') + return 'json_array_length(', builder(value), ')' + def JSON_CONTAINS(builder, expr, path, key): + # if builder.json1_available: + # TODO impl + with builder.json1_disabled(): + return 'py_json_contains(', builder(expr), ', ', builder(path), ', ', builder(key), ')' + def PY_JSON_NONZERO(builder, expr, path): + return 'py_json_nonzero(', builder(expr), ', ', builder(path), ')' + + @contextmanager + def json1_disabled(builder): + was_available = builder.json1_available + builder.json1_available = False + yield + builder.json1_available = was_available + class SQLiteIntConverter(dbapiprovider.IntConverter): def sql_type(converter): attr = converter.attr @@ -175,6 +250,23 @@ def sql2py(converter, val): def py2sql(converter, val): return datetime2timestamp(val) +class SQLiteJsonConverter(dbapiprovider.JsonConverter): + json_kwargs = {'separators': (',', ':')} + +def print_traceback(func): + 
@wraps(func) + def wrapper(*args, **kw): + try: + return func(*args, **kw) + except: + if core.debug: + import traceback + msg = traceback.format_exc() + log_orm(msg) + raise + return wrapper + + class SQLiteProvider(DBAPIProvider): dialect = 'SQLite' max_name_len = 1024 @@ -201,13 +293,19 @@ class SQLiteProvider(DBAPIProvider): (time, SQLiteTimeConverter), (timedelta, SQLiteTimedeltaConverter), (UUID, dbapiprovider.UuidConverter), - (buffer, dbapiprovider.BlobConverter), - ] + (buffer, dbapiprovider.BlobConverter), + (ormtypes.Json, SQLiteJsonConverter) + ] def __init__(provider, *args, **kwargs): DBAPIProvider.__init__(provider, *args, **kwargs) provider.transaction_lock = Lock() + @wrap_dbapi_exceptions + def inspect_connection(provider, conn): + DBAPIProvider.inspect_connection(provider, conn) + provider.json1_available = provider.check_json1(conn) + @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction @@ -317,6 +415,16 @@ def _exists(provider, connection, table_name, index_name=None, case_sensitive=Tr def fk_exists(provider, connection, table_name, fk_name): assert False # pragma: no cover + def check_json1(provider, connection): + cursor = connection.cursor() + sql = ''' + select json('{"this": "is", "a": ["test"]}')''' + try: + cursor.execute(sql) + return True + except sqlite.OperationalError: + return False + provider_cls = SQLiteProvider def _text_factory(s): @@ -340,6 +448,51 @@ def func(value): py_upper = make_string_function('py_upper', unicode.upper) py_lower = make_string_function('py_lower', unicode.lower) +@print_traceback +def unwrap_extract_json(value): + ''' + [null,some-value] -> some-value + ''' + assert value.startswith('[null,') + return value[6:-1] + +@print_traceback +def py_json_extract(value, path, quote_strings): + value = json.loads(value) + for item in json.loads(path): + try: + value = value[item] + except (KeyError, IndexError): + value = None + break + if isinstance(value, 
int) and not isinstance(value, bool): + return value + if isinstance(value, basestring) and not quote_strings: + return value + return json.dumps(value, separators=(',', ':')) + +@print_traceback +def py_json_contains(value, path, key): + value = json.loads(value) + try: + for item in json.loads(path): + value = value[item] + except (KeyError, IndexError): + value = None + if isinstance(value, (list, dict)): + return key in value + +@print_traceback +def py_json_nonzero(value, path): + value = json.loads(value) + try: + for item in json.loads(path): + value = value[item] + except (KeyError, IndexError): + value = None + return bool(value) + + class SQLitePool(Pool): def __init__(pool, filename, create_db): # called separately in each thread pool.filename = filename @@ -355,6 +508,11 @@ def _connect(pool): con.create_function('rand', 0, random) con.create_function('py_upper', 1, py_upper) con.create_function('py_lower', 1, py_lower) + con.create_function('unwrap_extract_json', 1, unwrap_extract_json) + con.create_function('py_json_extract', 3, py_json_extract) + con.create_function('py_json_contains', 3, py_json_contains) + con.create_function('py_json_nonzero', 2, py_json_nonzero) + con.create_function('py_lower', 1, py_lower) if sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') def disconnect(pool): diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 0b49d3120..49d2b2783 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -153,8 +153,9 @@ def normalize_type(t): if t is NoneType: return t t = type_normalization_dict.get(t, t) if t in primitive_types: return t + if issubclass(t, (slice, type(Ellipsis))): return t if issubclass(t, basestring): return unicode - if issubclass(t, dict): return dict + if issubclass(t, (dict, Json)): return Json throw(TypeError, 'Unsupported type %r' % t.__name__) coercions = { @@ -184,6 +185,10 @@ def are_comparable_types(t1, t2, op='=='): # types must be normalized already! 
tt1 = type(t1) tt2 = type(t2) + + t12 = {t1, t2} + if Json in t12 and t12 < {Json, str, unicode, int, bool, float}: + return True if op in ('in', 'not in'): if tt2 is RawSQLType: return True if tt2 is not SetType: return False @@ -280,3 +285,12 @@ def __reduce__(self): clear = tracked_method(list.clear) def get_untracked(self): return [val.get_untracked() if isinstance(val, TrackedValue) else val for val in self] + +class Json(object): + """A wrapper over a dict or list + """ + def __init__(self, wrapped): + self.wrapped = wrapped + + def __repr__(self): + return 'Json %s' % repr(self.wrapped) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index eec272bc5..599cee879 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -503,3 +503,17 @@ def RANDOM(builder): def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] + def JSON_PATH(builder, *items): + ret = ['\'$'] + for item in items: + if isinstance(item, int): + ret.append('[%d]' % item) + elif isinstance(item, str): + ret.append('."%s"' % item) + else: assert 0 + ret.append('\'') + return ret + def JSON_GETPATH(builder, expr, key): + raise NotImplementedError + def CAST(builder, expr, type): + return 'CAST(', builder(expr), ' AS ', type, ')' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 9564ae240..a6f244a4f 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -16,7 +16,8 @@ from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ - get_normalized_type_of, normalize_type, coerce_types, are_comparable_types + get_normalized_type_of, normalize_type, coerce_types, are_comparable_types, \ + Json from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper @@ -334,6 +335,8 
@@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef offset = 0 provider = translator.database.provider for m in expr_monads: + if m.disable_distinct: + translator.distinct = False expr_type = m.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): @@ -352,7 +355,7 @@ def func(value, converter=converter): value = converter.dbval2val(value) return value row_layout.append((func, offset, ast2src(m.node))) - m.orderby_columns = (offset+1,) + m.orderby_columns = (offset+1,) if not m.disable_ordering else () offset += 1 translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in translator.row_layout ] @@ -964,6 +967,8 @@ class MonadMixin(with_metaclass(MonadMeta)): pass class Monad(with_metaclass(MonadMeta)): + disable_distinct = False + disable_ordering = False def __init__(monad, translator, type): monad.translator = translator monad.type = type @@ -1490,6 +1495,7 @@ def new(parent, attr, *args, **kwargs): elif type is datetime: cls = translator.DatetimeAttrMonad elif type is buffer: cls = translator.BufferAttrMonad elif type is UUID: cls = translator.UuidAttrMonad + elif type is Json: cls = translator.JsonAttrMonad elif isinstance(type, EntityMeta): cls = translator.ObjectAttrMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(parent, attr, *args, **kwargs) @@ -1544,6 +1550,90 @@ class DatetimeAttrMonad(DatetimeMixin, AttrMonad): pass class BufferAttrMonad(BufferMixin, AttrMonad): pass class UuidAttrMonad(UuidMixin, AttrMonad): pass + +class JsonMixin(object): + disable_distinct = True # at least in Oracle we cannot use DISTINCT with JSON column + disable_ordering = True # at least in Oracle we cannot use ORDER BY with JSON column + + @classmethod + def _get_path_sql(cls, items): + return ['JSON_PATH'] + [cls._get_value(item) for item in items] + + @classmethod + def _get_value(cls, monad): + tr = monad.translator + if 
not isinstance(monad, (tr.NumericConstMonad, tr.StringConstMonad)): + raise TypeError('Invalid JSON path item: %s' % ast2src(monad.node)) + return monad.value + + allow_subtract_key_syntax = False # support only subtracting path by default + + def __sub__(monad, other): + translator = monad.translator + left_sql, = monad.getsql() + items = None + if isinstance(other, translator.ListMonad): + items = other.items + elif not monad.allow_subtract_key_syntax: + items = [other] + else: + value = monad._get_value(other) + sql = ['JSON_SUBTRACT_VALUE', left_sql, value] + if items: + path = monad._get_path_sql(items) + sql = ['JSON_SUBTRACT_PATH', left_sql, path] + return translator.JsonExprMonad(translator, Json, sql) + + def __getitem__(monad, item, is_overriden=False): + ''' + Transform the item and return it. Please override. + ''' + assert is_overriden, 'Json.__getitem__ is not a valid implementation' + if isinstance(item, slice) \ + and isinstance(item.start, (NoneType, NoneMonad)) \ + and isinstance(item.stop, (NoneType, NoneMonad)): + return FullSliceMonad(monad.translator) + return item + + def contains(monad, item, not_in=False): + translator = monad.translator + expr = monad.translator.JsonContainsExprMonad(monad, item) + if not_in: + sql, = expr.getsql() + expr = translator.JsonBoolExprMonad(translator, bool, ['NOT', sql]) + return expr + + # TODO not_in + # def contains_json(monad, item, not_in=False): + # import ipdb; ipdb.set_trace() + # translator = monad.translator + # parent_sql, = monad.getsql() + # item_sql, = item.getsql() + # if isinstance(item, translator.JsonMixin): + # sql = ['JSON_CONTAINS_JSON', item_sql, parent_sql] + # return translator.JsonBoolExprMonad(monad.translator, bool, sql) + # elif isinstance(item, translator.StringMixin): + # sql = ['JSON_CONTAINS', item_sql, parent_sql] + # return translator.JsonBoolExprMonad(monad.translator, bool, sql) + # else: + # raise TypeError('Invalid JSON key: %s,' % ast2src(item.node)) + + def len(monad): + 
translator = monad.translator + sql, = monad.getsql() + return translator.NumericExprMonad( + translator, int, ['JSON_ARRAY_LENGTH', sql]) + +class JsonAttrMonad(JsonMixin, AttrMonad): + def __getitem__(monad, key): + key = JsonMixin.__getitem__(monad, key, True) + return monad.translator.JsonItemMonad(monad, [key]) + + @property + def attr_monad(monad): + return monad + + class ParamMonad(Monad): @staticmethod def new(translator, type, paramkey): @@ -1556,6 +1646,7 @@ def new(translator, type, paramkey): elif type is datetime: cls = translator.DatetimeParamMonad elif type is buffer: cls = translator.BufferParamMonad elif type is UUID: cls = translator.UuidParamMonad + elif type is Json: cls = translator.JsonParamMonad elif isinstance(type, EntityMeta): cls = translator.ObjectParamMonad else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type)) result = cls(translator, type, paramkey) @@ -1598,6 +1689,10 @@ class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass + +class JsonParamMonad(JsonMixin, ParamMonad): + pass + class ExprMonad(Monad): @staticmethod def new(translator, type, sql): @@ -1625,6 +1720,112 @@ class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass +class JsonBoolExprMonad(ExprMonad): + pass + +class JsonContainsExprMonad(Monad): + def __init__(monad, json_monad, item): + monad.json_monad = json_monad + monad.item = item + Monad.__init__(monad, json_monad.translator, bool) + monad.attr_sql = json_monad.attr_monad.getsql()[0] + + def getsql(monad): + json_monad = monad.json_monad + path_sql = json_monad._get_path_sql( + getattr(json_monad, 'path', ()) + ) + item_sql, = monad.item.getsql() + return [ + ['JSON_CONTAINS', monad.attr_sql, path_sql, item_sql] + ] + + +class CastFromJsonExprMonad(ExprMonad): + + def 
__init__(monad, type_to, translator, sql): + monad.type_to = type_to + ExprMonad.__init__(monad, type_to, translator, sql) + + + @classmethod + def dispatch_type(cls, typ): + if issubclass(typ, bool): + return 'boolean' + if issubclass(typ, int): + return 'integer' + if issubclass(typ, float): + return 'real' + + + def getsql(monad): + sql_type = monad.dispatch_type(monad.type_to) + sql = ['CAST', monad.sql, sql_type] + return [sql] + + +class CastToJsonExprMonad(ExprMonad): + + cast_to = 'JSON' + target_monad = None + + def __new__(cls, *args, **kwargs): + kwargs.pop('target_monad', None) + return ExprMonad.__new__(cls, *args, **kwargs) + + def __init__(monad, translator, sql, target_monad=None): + ExprMonad.__init__(monad, translator, Json, sql) + if target_monad: + monad.target_monad = target_monad + + def getsql(monad): + sql = monad.sql + if monad.target_monad: + m = monad.target_monad + if isinstance(m, ConstMonad) and issubclass(m.type, bool): + sql = [ + "RAWSQL", + "'%s'" % ('true' if m.value else 'false') + ] + sql = ['CAST', sql, monad.cast_to] + return [sql] + + +class AbortCast(Exception): + pass + +class JsonExprMonad(JsonMixin, ExprMonad): + pass + +class JsonItemMonad(JsonMixin, Monad): + + allow_get_by_key_syntax = False + + def __init__(monad, attr_monad, path): + translator = attr_monad.translator + monad.attr_monad = attr_monad + monad.path = path + Monad.__init__(monad, translator, Json) + + def __getitem__(monad, key): + key = JsonMixin.__getitem__(monad, key, True) + return monad.translator.JsonItemMonad( + monad.attr_monad, monad.path + [key]) + + def getsql(monad): + base_sql, = monad.attr_monad.getsql() + if monad.allow_get_by_key_syntax and len(monad.path) == 1: + value = monad._get_value(monad.path[0]) + sql = ['JSON_GET', base_sql, value] + return [sql] + path_sql = monad._get_path_sql(monad.path) + sql = ['JSON_GETPATH'] + sql.extend((base_sql, path_sql)) + return [sql] + + def nonzero(monad): + return monad + class ConstMonad(Monad): 
@staticmethod def new(translator, value): @@ -1661,6 +1862,11 @@ def __init__(monad, translator, value=None): class EllipsisMonad(ConstMonad): pass +class FullSliceMonad(ConstMonad): + SLICE = slice(None, None, None) + def __init__(monad, translator): + ConstMonad.__init__(monad, translator, monad.SLICE) + class BufferConstMonad(BufferMixin, ConstMonad): pass class StringConstMonad(StringMixin, ConstMonad): @@ -1727,8 +1933,59 @@ def __init__(monad, op, left, right): monad.left = left monad.right = right monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False) + + if isinstance(left, JsonMixin): + json_monad, other_monad = left, right + elif isinstance(right, JsonMixin): + json_monad, other_monad = right, left + else: + return + + # Customizing comparisons for Json + if op in ('==', '!='): + if isinstance(other_monad, StringConstMonad): + other_monad.value = '"%s"' % right.value + elif isinstance(other_monad, ParamMonad): + other_monad.converter = translator.database.provider \ + .get_converter_by_py_type(Json) + def negate(monad): return monad.translator.CmpMonad(cmp_negate[monad.op], monad.left, monad.right) + + def make_json_cast_if_needed(monad, left_sql, right_sql): + translator = monad.left.translator + is_needed = monad.op in ('<', '>', '==', '!=') and any( + isinstance(m, NumericMixin) for m in (monad.left, monad.right) + ) + if not is_needed: + return left_sql, right_sql + # special handling for boolean constants + if monad.op in ('==', '!='): + if isinstance(monad.left, ConstMonad) and issubclass(monad.left.type, bool): + # FIXME use CastToJson + bool_sql = [ + "RAWSQL", + "'%s'" % ('true' if monad.left.value else 'false') + ] + return [bool_sql], right_sql + if isinstance(monad.right, ConstMonad) and issubclass(monad.right.type, bool): + bool_sql = [ + "RAWSQL", + "'%s'" % ('true' if monad.right.value else 'false') + ] + return left_sql, [bool_sql] + if isinstance(monad.left, JsonMixin): + other_monad = monad.right + 
expr = translator.CastFromJsonExprMonad( + other_monad.type, translator, left_sql[0] + ) + return expr.getsql(), right_sql + other_monad = monad.left + expr = translator.CastFromJsonExprMonad( + other_monad.type, translator, right_sql[0] + ) + return left_sql, expr.getsql() + def getsql(monad, subquery=None): op = monad.op left_sql = monad.left.getsql() @@ -1737,6 +1994,12 @@ def getsql(monad, subquery=None): if op == 'is not': return [ sqland([ [ 'IS_NOT_NULL', item ] for item in left_sql ]) ] right_sql = monad.right.getsql() + + if any(isinstance(m, JsonMixin) for m in (monad.left, monad.right)): + try: + left_sql, right_sql = monad.make_json_cast_if_needed(left_sql, right_sql) + except AbortCast: + pass if len(left_sql) == 1 and left_sql[0][0] == 'ROW': left_sql = left_sql[0][1:] if len(right_sql) == 1 and right_sql[0][0] == 'ROW': @@ -1780,6 +2043,11 @@ def getsql(monad, subquery=None): result.extend(operand_sql) return [ result ] + +class JsonConcatExprMonad(JsonMixin, ExprMonad): + pass + + class AndMonad(LogicalBinOpMonad): binop = 'AND' diff --git a/pony/orm/tests/test_json/__init__.py b/pony/orm/tests/test_json/__init__.py new file mode 100644 index 000000000..c3b5dd15d --- /dev/null +++ b/pony/orm/tests/test_json/__init__.py @@ -0,0 +1,47 @@ + +from pony.orm import * + +class SetupTest(object): + + E = NotImplemented + + @classmethod + def setUpClass(cls): + cls.bindDb() + cls.prepareDb() + cls.db.generate_mapping(create_tables=True) + + @classmethod + def tearDownClass(cls): + with db_session: + cls.db.execute(""" + drop table e + """) + + @db_session + def tearDown(self): + select(m for m in self.E).delete() + + @classmethod + def prepareDb(cls): + class E(cls.db.Entity): + article = Required(str) + info = Optional(ormtypes.Json) + extra_info = Optional(ormtypes.Json) + zero = Optional(int) + + cls.M = cls.E = E + + @db_session + def setUp(self): + info = [ + 'description', + 4, + {'size': '100x50'}, + ['item1', 'item2', 'smth', 'else'], + ] + 
extra_info = {'info': ['warranty 1 year', '2 weeks testing']} + self.E(article='A-347', info=info, extra_info=extra_info) + + + diff --git a/pony/orm/tests/test_json/_postgres.py b/pony/orm/tests/test_json/_postgres.py new file mode 100644 index 000000000..7e94e9016 --- /dev/null +++ b/pony/orm/tests/test_json/_postgres.py @@ -0,0 +1,127 @@ +''' +Postgres-specific tests +''' + +import unittest + +from pony.orm import * +from pony.orm.ormtypes import Json +from pony.orm.tests.testutils import raises_exception + +from . import SetupTest + + +class JsonConcatTest(SetupTest, unittest.TestCase): + + @classmethod + def bindDb(cls): + cls.db = Database('postgres', user='postgres', password='postgres', + database='testjson', host='localhost') + + @db_session + def setUp(self): + info = ['description', 4, {'size': '100x50'}] + self.E(article='A-347', info=info, extra_info={'overpriced': True}) + + @db_session + def test_field(self): + result = select(m.info[2] | m.extra_info for m in self.M)[:] + self.assertDictEqual(result[0], {u'overpriced': True, u'size': u'100x50'}) + + @db_session + def test_param(self): + x = 17 + result = select(m.info[2] | {"weight": x} for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) + + @db_session + def test_complex_param(self): + x = {"weight": {'net': 17}} + result = select(m.info[2] | x for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertDictEqual(result[0], {'weight': {'net': 17}, 'size': '100x50'}) + + @db_session + def test_complex_param_2(self): + x = {'net': 17} + result = select(m.info[2] | {"weight": x} for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertDictEqual(result[0], {'weight': {'net': 17}, 'size': '100x50'}) + + @db_session + def test_str_const(self): + result = select(m.info[2] | {"weight": 17} for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) + + 
@db_session + def test_str_param(self): + extra = {"weight": 17} + result = select(m.info[2] | extra for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) + + @raises_exception(Exception) + @db_session + def test_no_json_wrapper(self): + result = select(m.info[2] | '{"weight": 17}' for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) + + +class JsonContainsTest(SetupTest, unittest.TestCase): + + @classmethod + def bindDb(cls): + cls.db = Database('postgres', user='postgres', password='postgres', + database='testjson', host='localhost') + + @db_session + def setUp(self): + info = ['description', 4, {'size': '100x50'}] + self.M(article='A-347', info=info, extra_info={'overpriced': True}) + + @db_session + def test_key_in(self): + result = select('size' in m.info[2] for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], True) + + @db_session + def test_contains(self): + result = select({"size": "100x50"} in m.info[2] for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], True) + + @db_session + def test_contains_param(self): + for size in ['100x50', '200x100']: + result = select({"size": "%s" % size} in m.info[2] for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], size == '100x50') + + @db_session + def test_list(self): + result = select(Json(["description"]) in m.info for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], True) + + @db_session + def test_contains_field(self): + result = select({"size": "100x50"} in m.info[2] for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], True) + + @db_session + def test_inverse_order(self): + result = select(m.info[2] in {"size": "100x50", "weight": 1} for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], 
True) + + @db_session + def test_with_concat(self): + result = select((m.info[2] | {'weight': 1}) in {"size": "100x50", "weight": 1} + for m in self.M)[:] + self.assertEqual(len(result), 1) + self.assertEqual(result[0], True) diff --git a/pony/orm/tests/test_json/test.py b/pony/orm/tests/test_json/test.py new file mode 100644 index 000000000..accef1e2b --- /dev/null +++ b/pony/orm/tests/test_json/test.py @@ -0,0 +1,609 @@ +# *uses fixtures* + +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import raises_exception +from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict + +from contextlib import contextmanager + +import pony.fixtures +from ponytest import with_cli_args + + + +def no_json1_fixture(cls): + if cls.db_provider != 'sqlite': + raise unittest.SkipTest + + cls.no_json1 = True + + @contextmanager + def mgr(): + json1_available = cls.db.provider.json1_available + cls.db.provider.json1_available = False + try: + yield + finally: + cls.db.provider.json1_available = json1_available + + return mgr() + + +no_json1_fixture.class_scoped = True + +import click + + +@contextmanager +def empty_mgr(*args, **kw): + yield + + +@with_cli_args +@click.option('--json1', flag_value=True, default=None) +@click.option('--no-json1', 'json1', flag_value=False) +def json1_cli(json1): + if json1 is None or json1 is True: + yield empty_mgr + if json1 is None or json1 is False: + yield no_json1_fixture + + + +class JsonTest(unittest.TestCase): + in_db_session = False + + @classmethod + def make_entities(cls): + class E(cls.db.Entity): + article = Required(str) + info = Optional(ormtypes.Json) + extra_info = Optional(ormtypes.Json) + zero = Optional(int) + DESCRIPTION = Optional(str, default='description') + + class F(cls.db.Entity): + info = Optional(ormtypes.Json) + + cls.M = cls.E = cls.db.E + + from ponytest import pony_fixtures + pony_fixtures = list(pony_fixtures) + [json1_cli] + + @db_session + def setUp(self): + 
self.db.execute('delete from %s' % self.db.E._table_) + self.db.execute('delete from %s' % self.db.F._table_) + + info = [ + 'description', + 4, + {'size': '100x50'}, + ['item1', 'item2', 'smth', 'else'], + ] + extra_info = {'info': ['warranty 1 year', '2 weeks testing']} + self.db.E(article='A-347', info=info, extra_info=extra_info) + + + def test_int(self): + Merchandise = self.M + with db_session: + qs = select(b.info[1] for b in Merchandise)[:] + self.assertEqual(qs[0], 4) + + def test(self): + Merchandise = self.M + with db_session: + qs = select(b for b in Merchandise)[:] + self.assertEqual(len(qs), 1) + o = qs[0] + o.info[2]['weight'] = '3 kg' + with db_session: + qs = select(b for b in Merchandise)[:] + self.assertEqual(len(qs), 1) + o = qs[0] + self.assertEqual(o.info[2]['weight'], '3 kg') + + def test_sqlite_sql_inject(self): + # py_json_extract + with db_session: + o = select(m for m in self.M).first() + o.info = {'text' : "3 ' kg"} + with db_session: + o = select(m.info['text'] for m in self.M).first() + # test quote in json is not causing error + + def test_set_list(self): + with db_session: + qs = select(m for m in self.M)[:] + self.assertEqual(len(qs), 1) + o = qs[0] + o.info[2] = ['some', 'list'] + with db_session: + val = select(m.info[2] for m in self.M).first() + self.assertListEqual(val, ['some', 'list']) + + def test_getitem_int(self): + Merchandise = self.M + with db_session: + qs = select(b.info[0] for b in Merchandise)[:] + self.assertEqual(qs[0], 'description') + + def test_getitem_str(self): + Merchandise = self.M + with db_session: + qs = select(b.info[2]['size'] for b in Merchandise)[:] + self.assertEqual(qs[0], '100x50') + + @db_session + def test_delete_str(self): + if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): + raise unittest.SkipTest + def g(): + for m in self.M: + yield m.info[2] - 'size' + val = select(g()).first() + self.assertDictEqual(val, {}) + + @raises_exception(TypeError) # only constants are 
supported + @db_session + def test_delete_field(self): + if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): + raise unittest.SkipTest + qs = select(m.info - m.zero for m in self.M)[:] + self.assertEqual(len(qs), 1) + val = qs[0] + self.assertEqual(val[0], 4) + + @db_session + def test_delete_path(self): + if self.db_provider== 'oracle' or getattr(self, 'no_json1', False): + raise unittest.SkipTest + val = select(m.info - [2, 'size'] for m in self.M).first() + self.assertDictEqual(val[2], {}) + + # JSON length + + @db_session + def test_len(self): + if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): + raise unittest.SkipTest + g = (len(m.info) for m in self.M) + val = select(g).first() + self.assertEqual(val, 4) + + @db_session + def test_item_len(self): + if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): + raise unittest.SkipTest + g = (len(m.info[3]) for m in self.M) + val = select(g).first() + self.assertEqual(val, 4) + + # Tracked attribute + + def test_tracked_attr(self): + with db_session: + val = select(m for m in self.M).first() + val.info = val.extra_info['info'] + self.assertIsInstance(val.info, TrackedValue) + with db_session: + o = select(m for m in self.M).first() + self.assertListEqual(o.info, o.extra_info['info']) + + @db_session + def test_tracked_attr_type(self): + val = select(m.extra_info['info'] for m in self.M).first() + self.assertEqual(type(val), list) + o = select(m for m in self.M).first() + self.assertEqual(type(o.extra_info), TrackedDict) + self.assertEqual(type(o.extra_info['info']), TrackedList) + + def test_tracked_del(self): + with db_session: + d = select(m for m in self.M).first() + del d.info[2]['size'] + with db_session: + d = select(m.info[2] for m in self.M).first() + self.assertDictEqual(d, {}) + + # # Json equality + + @db_session + def test_equal_str(self): + g = (m.info[1] for m in self.M if m.info[0] == 'description') + val = select(g).first() + self.assertTrue(val) + + 
@db_session + def test_equal_string_attr(self): + if self.db_provider == 'oracle': + raise unittest.SkipTest + g = (m.info[1] for m in self.M if m.info[0] == m.DESCRIPTION) + val = select(g).first() + self.assertTrue(val) + + @db_session + def test_equal_param(self): + if self.db_provider == 'oracle': + raise unittest.SkipTest + x = 'description' + g = (m.info[1] for m in self.M if m.info[0] == x) + val = select(g).first() + self.assertTrue(val) + + @db_session + def test_computed_param(self): + index = 2 + key = 'size' + qs = select(b.info[index][key] for b in self.db.E)[:] + self.assertEqual(qs[0], '100x50') + + + @db_session + def test_equal_json(self): + if self.db_provider == 'oracle': + raise unittest.SkipTest + g = (m.info[2] for m in self.M if m.info[2] == {"size":"100x50"}) + val = select(g).first() + self.assertTrue(val) + + @db_session + def test_ne_json(self): + if self.db_provider == 'oracle': + raise unittest.SkipTest + g = (m.info[2] for m in self.M if m.info[2] != {"size":"200x50"}) + val = select(g).first() + self.assertTrue(val) + g = (m.info[2] for m in self.M if m.info[2] != {"size":"100x50"}) + val = select(g).first() + self.assertFalse(val) + + def test_equal_attr(self): + if self.db_provider == 'oracle': + raise unittest.SkipTest + with db_session: + e = select(e for e in self.db.E).first() + f = self.db.F(info=e.info[2]) + with db_session: + g = (e.info[2] + for e in self.db.E for f in self.db.F + if e.info[2] == f.info) + val = select(g).first() + self.assertTrue(val) + + @db_session + def test_equal_list(self): + if self.db_provider == 'oracle': + raise unittest.SkipTest + li = ['item1', 'item2', 'smth', 'else'] + self.assertTrue( + get(m for m in self.M if m.info[3] == Json(li)) + ) + + @db_session + def test_dbval2val(self): + with db_session: + obj = select(e for e in self.E)[:][0] + self.assertIsInstance(obj.info, TrackedValue) + obj.info[3][0] = 'trash' + with db_session: + obj = select(e for e in self.E)[:][0] + dbval = 
obj._dbvals_[self.E.info] + val = obj._vals_[self.E.info] + self.assertIn('trash', str(dbval)) + self.assertIsInstance(dbval, str) + self.assertIsInstance(val, TrackedValue) + + @db_session + def test_starred_path1(self): + if self.db_provider not in ['mysql', 'oracle']: + raise unittest.SkipTest('* in path is not supported by %s' % self.db_provider) + g = select(e.info[:][...] for e in self.E) + for val in g[:]: + self.assertListEqual(val, ['100x50']) + + @db_session + def test_starred_gen_as_string(self): + if self.db_provider not in ['mysql', 'oracle']: + raise unittest.SkipTest('* in path is not supported by %s' % self.db_provider) + g = select('e.info[:][...] for e in self.E') + for val in g[:]: + self.assertListEqual(val, ['100x50']) + + @db_session + def test_starred_path2(self): + if self.db_provider not in ['mysql', 'oracle']: + raise unittest.SkipTest('* in path is not supported by %s' % self.db_provider) + g = select(e.extra_info[...][0] for e in self.E) + for val in g[:]: + self.assertListEqual(val, ['warranty 1 year']) + + ##### 'key' in json + + @db_session + def test_in_dict(self): + obj = select( + m.info[2]['size'] for m in self.M if 'size' in m.info[2] + ).first() + self.assertTrue(obj) + + @db_session + def test_not_in_dict(self): + obj = select( + m.info for m in self.M if 'size' not in m.info[2] + ).first() + self.assertEqual(obj, None) + obj = select( + m.info for m in self.M if 'siz' not in m.info[2] + ).first() + self.assertTrue(obj) + + @db_session + def test_in_list(self): + obj = select( + m.info[3] for m in self.M if 'item1' in m.info[3] + ).first() + self.assertTrue(obj) + obj = select( + m.info for m in self.M if 'description' in m.info + ).first() + self.assertTrue(obj) + + @db_session + def test_not_in_list(self): + obj = select( + m.info[3] for m in self.M if 'item1' not in m.info[3] + ).first() + self.assertEqual(obj, None) + obj = select( + m.info[3] for m in self.M if 'ite' not in m.info[3] + ).first() + self.assertIn('item1', 
obj) + + @db_session + def test_var_in_json(self): + if self.db_provider in ('mysql', 'oracle'): + if_implemented = lambda: self.assertRaises(NotImplementedError) + else: + @contextmanager + def if_implemented(): + yield + with if_implemented(): + key = 'item1' + obj = select( + m.info[3] for m in self.M if key in m.info[3] + ).first() + self.assertTrue(obj) + + @db_session + def test_get_json_attr(self): + ''' query should not contain distinct + ''' + if self.db_provider != 'oracle': + raise unittest.SkipTest + obj = get( + m.info for m in self.M + ) + self.assertTrue(obj) + + @db_session + def test_select_first(self): + ''' query shoud not contain ORDER BY + ''' + if self.db_provider != 'oracle': + raise unittest.SkipTest + obj = select( + m.info for m in self.M + ).first() + self.assertTrue(obj) + + def test_in_json_regexp(self): + if self.db_provider != 'oracle': + raise unittest.SkipTest + import re + from pony.orm.dbproviders.oracle import search_in_json_list_regexp + regexp = search_in_json_list_regexp('item') + pos = [ + '["item"]', + '[0, "item"]', + '[{}, "item", []]', + '[{"a": 1}, "item", []]', + '[false, "item", "erg"]', + ] + for s in pos: + self.assertTrue(re.search(regexp, s)) + neg = [ + '[["item"]]', + '[{"item": 0]]', + '["1 item", "item 1"]', + '[0, " "]', + '[]', + ] + for s in neg: + self.assertFalse(re.search(regexp, s)) + + +class TestDataTypes(unittest.TestCase): + + in_db_session = False + + from ponytest import pony_fixtures + pony_fixtures = list(pony_fixtures) + [json1_cli] + + @classmethod + def make_entities(cls): + class Data(cls.db.Entity): + data = Optional(Json) + + @db_session + def setUp(self): + self.db.execute('delete from %s' % self.db.Data._table_) + + + def test_int(self): + + db = self.db + with db_session: + db.Data(data={'val': 1}) + + with db_session: + obj = get(d for d in db.Data if d.data['val'] == 1) + self.assertEqual(obj.data['val'], 1) + + def test_compare_int(self): + db = self.db + with db_session: + 
db.Data(data={'val': 3}) + + with db_session: + self.assertTrue( + get(d for d in db.Data if d.data['val'] > 2) + ) + self.assertTrue( + get(d for d in db.Data if d.data['val'] < 4) + ) + + def test_str(self): + db = self.db + with db_session: + db.Data(data={'val': "1"}) + + with db_session: + obj = get(d for d in db.Data if d.data['val'] == '1') + self.assertTrue(obj) + + def test_none(self): + db = self.db + with db_session: + db.Data() + + with db_session: + data = get(d for d in db.Data if d.data is None) + self.assertTrue(data) + + # def test_is_null(self): + # db = self.db + # with db_session: + # db.Data(data={'val': None}) + + # with db_session: + # data = get(d for d in db.Data if d.data['val'] is None) + # self.assertTrue(data) + + # def test_eq_null(self): + # db = self.db + # with db_session: + # db.Data(data={'val': None}) + + # with db_session: + # data = get(d for d in db.Data if d.data['val'] == None) + # self.assertTrue(data) + + def test_bool(self): + with db_session: + self.db.Data(data={'val': True, 'id': 1}) + self.db.Data(data={'val': False, 'id': 2}) + + with db_session: + val = get( + d.data['id'] for d in self.db.Data + if d.data['val'] == False + ) + self.assertEqual(val, 2) + val = get( + d.data['id'] for d in self.db.Data + if d.data['val'] == True + ) + self.assertEqual(val, 1) + + def test_nonzero(self): + with db_session: + self.db.Data(data={'val': True, 'id': 1}) + self.db.Data(data={'val': False, 'id': 2}) + self.db.Data(data={'val': 0, 'id': 3}) + self.db.Data(data={'val': '', 'id': 4}) + self.db.Data(data={'id': 5}) + + if self.db_provider == 'oracle': + assert_raises = lambda: self.assertRaises(NotImplementedError) + else: + @contextmanager + def assert_raises(): + yield + + with db_session, assert_raises(): + val = get( + d.data['id'] for d in self.db.Data + if d.data['val'] + ) + self.assertEqual(val, 1) + + + def test_float(self): + with db_session: + self.db.Data(data={'val': 3.14}) + + with db_session: + val = 
get(d.data['val'] for d in self.db.Data) + self.assertIsInstance(val, float) + + def test_compare_float(self): + with db_session: + self.db.Data(data={'val': 3.14}) + with db_session: + val = get( + d.data['val'] for d in self.db.Data + if d.data['val'] < 3.15 + ) + self.assertIsInstance(val, float) + + + +# from ._postgres import JsonConcatTest #, JsonContainsTest # TODO + + +class TestSqliteFallback(unittest.TestCase): + + from ponytest import pony_fixtures + pony_fixtures = list(pony_fixtures) + [ + [no_json1_fixture] + ] + + @classmethod + def make_entities(cls): + class Person(cls.db.Entity): + name = Required(str) + data = Optional(Json) + + + def setUp(self): + self.db.execute('delete from %s' % self.db.Person._table_) + + + def test(self): + Person = self.db.Person + with db_session: + Person(name='John') + Person(name='Mike', data=dict(a=1,b=2)) + with db_session: + p = Person[1] + p.data = dict(c=[2, 3, 4], d='d') + p = Person[2] + p.data['c'] = [1, 2, 3] + + qs = select(p for p in Person if p.data['c'][1] == 2) + self.assertEqual(qs.count(), 1) + + + def test_cmp(self): + Person = self.db.Person + with db_session: + Person(name='Mike', data=[4]) + with db_session: + qs = select(p for p in Person if p.data[0] < 5) + self.assertEqual(qs.count(), 1) + qs = select(p for p in Person if p.data[0] > 3) + self.assertEqual(qs.count(), 1) \ No newline at end of file diff --git a/pony/utils/__init__.py b/pony/utils/__init__.py new file mode 100644 index 000000000..83f8d4248 --- /dev/null +++ b/pony/utils/__init__.py @@ -0,0 +1,4 @@ + + +from .utils import * +from .properties import * \ No newline at end of file diff --git a/pony/utils/properties.py b/pony/utils/properties.py new file mode 100644 index 000000000..eedccd7d5 --- /dev/null +++ b/pony/utils/properties.py @@ -0,0 +1,40 @@ + + +class cached_property(object): + """ + A property that is only computed once per instance and then replaces itself + with an ordinary attribute. 
Deleting the attribute resets the property. + Source: https://github.com/bottlepy/bottle/commit/fa7733e075da0d790d809aa3d2f53071897e6f76 + """ # noqa + + def __init__(self, func): + self.__doc__ = getattr(func, '__doc__') + self.func = func + + def __get__(self, obj, cls): + if obj is None: + return self + value = obj.__dict__[self.func.__name__] = self.func(obj) + return value + + +class class_property(object): + """ + Read-only class property + """ + + def __init__(self, func): + self.func = func + + def __get__(self, instance, cls): + return self.func(cls) + +class class_cached_property(object): + + def __init__(self, func): + self.func = func + + def __get__(self, obj, cls): + value = self.func(cls) + setattr(cls, self.func.__name__, value) + return value \ No newline at end of file diff --git a/pony/utils.py b/pony/utils/utils.py similarity index 97% rename from pony/utils.py rename to pony/utils/utils.py index a6475e790..f6a04e15a 100644 --- a/pony/utils.py +++ b/pony/utils/utils.py @@ -33,6 +33,7 @@ def _deepcopy_method(x, memo): if pony.MODE.startswith('GAE-'): localbase = object else: from threading import local as localbase + class PonyDeprecationWarning(DeprecationWarning): pass @@ -499,4 +500,4 @@ def concat(*args): return ''.join(tostring(arg) for arg in args) def is_utf8(encoding): - return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') + return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') \ No newline at end of file From 1f0ba7d50baf8ae6c241046767f810457397ccc1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 18 Aug 2016 14:09:37 +0300 Subject: [PATCH 023/547] Explain why OraJsonConverter is optimistic --- pony/orm/dbproviders/oracle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index a13d474eb..b07a39e72 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ 
-362,7 +362,7 @@ def sql_type(converter): return 'RAW(16)' class OraJsonConverter(dbapiprovider.JsonConverter): - optimistic = False + optimistic = False # CLOBs cannot be compared with strings, and TO_CHAR(CLOB) returns first 4000 chars only def sql2py(converter, dbval): if hasattr(dbval, 'read'): dbval = dbval.read() return dbapiprovider.JsonConverter.sql2py(converter, dbval) From 64fe3ba9febe4259d8c30b68e4564291e1719a43 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 29 Jul 2016 15:17:56 +0300 Subject: [PATCH 024/547] Update tests --- pony/orm/tests/test_json/__init__.py | 47 -- pony/orm/tests/test_json/test.py | 974 ++++++++++++++------------- pony/orm/tests/testutils.py | 22 + 3 files changed, 543 insertions(+), 500 deletions(-) diff --git a/pony/orm/tests/test_json/__init__.py b/pony/orm/tests/test_json/__init__.py index c3b5dd15d..e69de29bb 100644 --- a/pony/orm/tests/test_json/__init__.py +++ b/pony/orm/tests/test_json/__init__.py @@ -1,47 +0,0 @@ - -from pony.orm import * - -class SetupTest(object): - - E = NotImplemented - - @classmethod - def setUpClass(cls): - cls.bindDb() - cls.prepareDb() - cls.db.generate_mapping(create_tables=True) - - @classmethod - def tearDownClass(cls): - with db_session: - cls.db.execute(""" - drop table e - """) - - @db_session - def tearDown(self): - select(m for m in self.E).delete() - - @classmethod - def prepareDb(cls): - class E(cls.db.Entity): - article = Required(str) - info = Optional(ormtypes.Json) - extra_info = Optional(ormtypes.Json) - zero = Optional(int) - - cls.M = cls.E = E - - @db_session - def setUp(self): - info = [ - 'description', - 4, - {'size': '100x50'}, - ['item1', 'item2', 'smth', 'else'], - ] - extra_info = {'info': ['warranty 1 year', '2 weeks testing']} - self.E(article='A-347', info=info, extra_info=extra_info) - - - diff --git a/pony/orm/tests/test_json/test.py b/pony/orm/tests/test_json/test.py index accef1e2b..f6982f24e 100644 --- a/pony/orm/tests/test_json/test.py +++ 
b/pony/orm/tests/test_json/test.py @@ -2,8 +2,10 @@ import unittest +import click + from pony.orm import * -from pony.orm.tests.testutils import raises_exception +from pony.orm.tests.testutils import raises_exception, raises_if from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict from contextlib import contextmanager @@ -12,7 +14,6 @@ from ponytest import with_cli_args - def no_json1_fixture(cls): if cls.db_provider != 'sqlite': raise unittest.SkipTest @@ -30,11 +31,8 @@ def mgr(): return mgr() - no_json1_fixture.class_scoped = True -import click - @contextmanager def empty_mgr(*args, **kw): @@ -51,559 +49,629 @@ def json1_cli(json1): yield no_json1_fixture - -class JsonTest(unittest.TestCase): +class TestJson(unittest.TestCase): in_db_session = False @classmethod def make_entities(cls): - class E(cls.db.Entity): - article = Required(str) - info = Optional(ormtypes.Json) - extra_info = Optional(ormtypes.Json) - zero = Optional(int) - DESCRIPTION = Optional(str, default='description') - - class F(cls.db.Entity): - info = Optional(ormtypes.Json) + class Product(cls.db.Entity): + name = Required(str) + info = Optional(Json) + tags = Optional(Json) - cls.M = cls.E = cls.db.E + cls.Product = cls.db.Product from ponytest import pony_fixtures pony_fixtures = list(pony_fixtures) + [json1_cli] @db_session def setUp(self): - self.db.execute('delete from %s' % self.db.E._table_) - self.db.execute('delete from %s' % self.db.F._table_) + self.db.execute('delete from %s' % self.db.Product._table_) + + self.Product( + name='Apple iPad Air 2', + info={ + 'name': 'Apple iPad Air 2', + 'display': { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }, + 'os': { + 'type': 'iOS', + 'version': '8' + }, + 'cpu': 'Apple A8X', + 'ram': '8GB', + 'colors': ['Gold', 'Silver', 'Space Gray'], + 'models': [ + { + 'name': 'Wi-Fi', + 'capacity': ['16GB', '64GB'], + 'height': 240, + 'width': 169.5, + 'depth': 6.1, + 'weight': 437, + 
}, + { + 'name': 'Wi-Fi + Cellular', + 'capacity': ['16GB', '64GB'], + 'height': 240, + 'width': 169.5, + 'depth': 6.1, + 'weight': 444, + }, + ], + 'discontinued': False, + 'videoUrl': None, + 'non-ascii-attr': u'\u0442\u0435\u0441\u0442' + }, + tags=['Tablets', 'Apple', 'Retina']) - info = [ - 'description', - 4, - {'size': '100x50'}, - ['item1', 'item2', 'smth', 'else'], - ] - extra_info = {'info': ['warranty 1 year', '2 weeks testing']} - self.db.E(article='A-347', info=info, extra_info=extra_info) - - def test_int(self): - Merchandise = self.M + def test(self): + with db_session: + result = select(p for p in self.Product)[:] + self.assertEqual(len(result), 1) + p = result[0] + p.info['os']['version'] = '9' with db_session: - qs = select(b.info[1] for b in Merchandise)[:] - self.assertEqual(qs[0], 4) + result = select(p for p in self.Product)[:] + self.assertEqual(len(result), 1) + p = result[0] + self.assertEqual(p.info['os']['version'], '9') - def test(self): - Merchandise = self.M + @db_session + def test_query_int(self): + val = get(p.info['display']['resolution'][0] for p in self.Product) + self.assertEqual(val, 2048) + + @db_session + def test_query_float(self): + val = get(p.info['display']['size'] for p in self.Product) + self.assertAlmostEqual(val, 9.7) + + @db_session + def test_query_true(self): + val = get(p.info['display']['multi-touch'] for p in self.Product) + self.assertIs(val, True) + + @db_session + def test_query_false(self): + val = get(p.info['discontinued'] for p in self.Product) + self.assertIs(val, False) + + @db_session + def test_query_null(self): + val = get(p.info['videoUrl'] for p in self.Product) + self.assertIs(val, None) + + @db_session + def test_query_list(self): + val = get(p.info['colors'] for p in self.Product) + self.assertListEqual(val, ['Gold', 'Silver', 'Space Gray']) + self.assertNotIsInstance(val, TrackedValue) + + @db_session + def test_query_dict(self): + val = get(p.info['display'] for p in self.Product) + 
self.assertDictEqual(val, { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }) + self.assertNotIsInstance(val, TrackedValue) + + @db_session + def test_query_json_field(self): + val = get(p.info for p in self.Product) + self.assertDictEqual(val['display'], { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }) + self.assertNotIsInstance(val['display'], TrackedDict) + val = get(p.tags for p in self.Product) + self.assertListEqual(val, ['Tablets', 'Apple', 'Retina']) + self.assertNotIsInstance(val, TrackedList) + + @db_session + def test_get_object(self): + p = get(p for p in self.Product) + self.assertDictEqual(p.info['display'], { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }) + self.assertEqual(p.info['discontinued'], False) + self.assertEqual(p.info['videoUrl'], None) + self.assertListEqual(p.tags, ['Tablets', 'Apple', 'Retina']) + self.assertIsInstance(p.info, TrackedDict) + self.assertIsInstance(p.info['display'], TrackedDict) + self.assertIsInstance(p.info['colors'], TrackedList) + self.assertIsInstance(p.tags, TrackedList) + + def test_set_str(self): with db_session: - qs = select(b for b in Merchandise)[:] - self.assertEqual(len(qs), 1) - o = qs[0] - o.info[2]['weight'] = '3 kg' + p = get(p for p in self.Product) + p.info['os']['version'] = '9' with db_session: - qs = select(b for b in Merchandise)[:] - self.assertEqual(len(qs), 1) - o = qs[0] - self.assertEqual(o.info[2]['weight'], '3 kg') + p = get(p for p in self.Product) + self.assertEqual(p.info['os']['version'], '9') - def test_sqlite_sql_inject(self): - # py_json_extract + def test_set_int(self): with db_session: - o = select(m for m in self.M).first() - o.info = {'text' : "3 ' kg"} + p = get(p for p in self.Product) + p.info['display']['resolution'][0] += 1 with db_session: - o = select(m.info['text'] for m in self.M).first() - # test quote in json is not causing 
error + p = get(p for p in self.Product) + self.assertEqual(p.info['display']['resolution'][0], 2049) - def test_set_list(self): + def test_set_true(self): with db_session: - qs = select(m for m in self.M)[:] - self.assertEqual(len(qs), 1) - o = qs[0] - o.info[2] = ['some', 'list'] + p = get(p for p in self.Product) + p.info['discontinued'] = True with db_session: - val = select(m.info[2] for m in self.M).first() - self.assertListEqual(val, ['some', 'list']) + p = get(p for p in self.Product) + self.assertIs(p.info['discontinued'], True) - def test_getitem_int(self): - Merchandise = self.M + def test_set_false(self): + with db_session: + p = get(p for p in self.Product) + p.info['display']['multi-touch'] = False with db_session: - qs = select(b.info[0] for b in Merchandise)[:] - self.assertEqual(qs[0], 'description') + p = get(p for p in self.Product) + self.assertIs(p.info['display']['multi-touch'], False) - def test_getitem_str(self): - Merchandise = self.M + def test_set_null(self): with db_session: - qs = select(b.info[2]['size'] for b in Merchandise)[:] - self.assertEqual(qs[0], '100x50') + p = get(p for p in self.Product) + p.info['display'] = None + with db_session: + p = get(p for p in self.Product) + self.assertIs(p.info['display'], None) - @db_session - def test_delete_str(self): - if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): - raise unittest.SkipTest - def g(): - for m in self.M: - yield m.info[2] - 'size' - val = select(g()).first() - self.assertDictEqual(val, {}) + def test_set_list(self): + with db_session: + p = get(p for p in self.Product) + p.info['colors'] = ['Pink', 'Black'] + with db_session: + p = get(p for p in self.Product) + self.assertListEqual(p.info['colors'], ['Pink', 'Black']) - @raises_exception(TypeError) # only constants are supported - @db_session - def test_delete_field(self): - if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): - raise unittest.SkipTest - qs = select(m.info - m.zero for m 
in self.M)[:] - self.assertEqual(len(qs), 1) - val = qs[0] - self.assertEqual(val[0], 4) + def test_list_del(self): + with db_session: + p = get(p for p in self.Product) + del p.info['colors'][1] + with db_session: + p = get(p for p in self.Product) + self.assertListEqual(p.info['colors'], ['Gold', 'Space Gray']) - @db_session - def test_delete_path(self): - if self.db_provider== 'oracle' or getattr(self, 'no_json1', False): - raise unittest.SkipTest - val = select(m.info - [2, 'size'] for m in self.M).first() - self.assertDictEqual(val[2], {}) + def test_list_append(self): + with db_session: + p = get(p for p in self.Product) + p.info['colors'].append('White') + with db_session: + p = get(p for p in self.Product) + self.assertListEqual(p.info['colors'], ['Gold', 'Silver', 'Space Gray', 'White']) - # JSON length + def test_list_set_slice(self): + with db_session: + p = get(p for p in self.Product) + p.info['colors'][1:] = ['White'] + with db_session: + p = get(p for p in self.Product) + self.assertListEqual(p.info['colors'], ['Gold', 'White']) - @db_session - def test_len(self): - if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): - raise unittest.SkipTest - g = (len(m.info) for m in self.M) - val = select(g).first() - self.assertEqual(val, 4) + def test_list_set_item(self): + with db_session: + p = get(p for p in self.Product) + p.info['colors'][1] = 'White' + with db_session: + p = get(p for p in self.Product) + self.assertListEqual(p.info['colors'], ['Gold', 'White', 'Space Gray']) - @db_session - def test_item_len(self): - if self.db_provider == 'oracle' or getattr(self, 'no_json1', False): - raise unittest.SkipTest - g = (len(m.info[3]) for m in self.M) - val = select(g).first() - self.assertEqual(val, 4) + def test_set_dict(self): + with db_session: + p = get(p for p in self.Product) + p.info['display']['resolution'] = {'width': 2048, 'height': 1536} + with db_session: + p = get(p for p in self.Product) + 
self.assertDictEqual(p.info['display']['resolution'], {'width': 2048, 'height': 1536}) - # Tracked attribute + def test_dict_del(self): + with db_session: + p = get(p for p in self.Product) + del p.info['os']['version'] + with db_session: + p = get(p for p in self.Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS'}) - def test_tracked_attr(self): + def test_dict_pop(self): with db_session: - val = select(m for m in self.M).first() - val.info = val.extra_info['info'] - self.assertIsInstance(val.info, TrackedValue) + p = get(p for p in self.Product) + p.info['os'].pop('version') with db_session: - o = select(m for m in self.M).first() - self.assertListEqual(o.info, o.extra_info['info']) + p = get(p for p in self.Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS'}) - @db_session - def test_tracked_attr_type(self): - val = select(m.extra_info['info'] for m in self.M).first() - self.assertEqual(type(val), list) - o = select(m for m in self.M).first() - self.assertEqual(type(o.extra_info), TrackedDict) - self.assertEqual(type(o.extra_info['info']), TrackedList) + def test_dict_update(self): + with db_session: + p = get(p for p in self.Product) + p.info['os'].update(version='9') + with db_session: + p = get(p for p in self.Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) - def test_tracked_del(self): + def test_dict_set_item(self): with db_session: - d = select(m for m in self.M).first() - del d.info[2]['size'] + p = get(p for p in self.Product) + p.info['os']['version'] = '9' with db_session: - d = select(m.info[2] for m in self.M).first() - self.assertDictEqual(d, {}) + p = get(p for p in self.Product) + self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) + + # JSON length + + @db_session + def test_len(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', + TranslationError, 'Oracle does not provide `length` function for JSON arrays'): + val = select(len(p.tags) for p in 
self.Product).first() + self.assertEqual(val, 3) + val = select(len(p.info['colors']) for p in self.Product).first() + self.assertEqual(val, 3) # # Json equality @db_session def test_equal_str(self): - g = (m.info[1] for m in self.M if m.info[0] == 'description') - val = select(g).first() - self.assertTrue(val) + p = get(p for p in self.Product if p.info['name'] == 'Apple iPad Air 2') + self.assertTrue(p) @db_session def test_equal_string_attr(self): - if self.db_provider == 'oracle': - raise unittest.SkipTest - g = (m.info[1] for m in self.M if m.info[0] == m.DESCRIPTION) - val = select(g).first() - self.assertTrue(val) + p = get(p for p in self.Product if p.info['name'] == p.name) + self.assertTrue(p) @db_session def test_equal_param(self): - if self.db_provider == 'oracle': - raise unittest.SkipTest - x = 'description' - g = (m.info[1] for m in self.M if m.info[0] == x) - val = select(g).first() - self.assertTrue(val) + x = 'Apple iPad Air 2' + p = get(p for p in self.Product if p.name == x) + self.assertTrue(p) @db_session - def test_computed_param(self): - index = 2 - key = 'size' - qs = select(b.info[index][key] for b in self.db.E)[:] - self.assertEqual(qs[0], '100x50') - + def test_composite_param(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', + TranslationError, "Oracle doesn't allow parameters in JSON paths"): + key = 'models' + index = 0 + val = get(p.info[key][index]['name'] for p in self.Product) + self.assertEqual(val, 'Wi-Fi') @db_session - def test_equal_json(self): - if self.db_provider == 'oracle': - raise unittest.SkipTest - g = (m.info[2] for m in self.M if m.info[2] == {"size":"100x50"}) - val = select(g).first() - self.assertTrue(val) + def test_composite_param_in_condition(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', + TranslationError, "Oracle doesn't allow parameters in JSON paths"): + key = 'models' + index = 0 + p = get(p for p in self.Product if p.info[key][index]['name'] == 'Wi-Fi') + 
self.assertIsNotNone(p) @db_session - def test_ne_json(self): - if self.db_provider == 'oracle': - raise unittest.SkipTest - g = (m.info[2] for m in self.M if m.info[2] != {"size":"200x50"}) - val = select(g).first() - self.assertTrue(val) - g = (m.info[2] for m in self.M if m.info[2] != {"size":"100x50"}) - val = select(g).first() - self.assertFalse(val) - - def test_equal_attr(self): - if self.db_provider == 'oracle': - raise unittest.SkipTest - with db_session: - e = select(e for e in self.db.E).first() - f = self.db.F(info=e.info[2]) - with db_session: - g = (e.info[2] - for e in self.db.E for f in self.db.F - if e.info[2] == f.info) - val = select(g).first() - self.assertTrue(val) + def test_equal_json_1(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: " + "p.info['os'] == {'type':'iOS', 'version':'8'}"): + p = get(p for p in self.Product if p.info['os'] == {'type': 'iOS', 'version': '8'}) + self.assertTrue(p) @db_session - def test_equal_list(self): - if self.db_provider == 'oracle': - raise unittest.SkipTest - li = ['item1', 'item2', 'smth', 'else'] - self.assertTrue( - get(m for m in self.M if m.info[3] == Json(li)) - ) + def test_equal_json_2(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: " + "p.info['os'] == Json({'type':'iOS', 'version':'8'})"): + p = get(p for p in self.Product if p.info['os'] == Json({'type': 'iOS', 'version': '8'})) + self.assertTrue(p) @db_session - def test_dbval2val(self): - with db_session: - obj = select(e for e in self.E)[:][0] - self.assertIsInstance(obj.info, TrackedValue) - obj.info[3][0] = 'trash' - with db_session: - obj = select(e for e in self.E)[:][0] - dbval = obj._dbvals_[self.E.info] - val = obj._vals_[self.E.info] - self.assertIn('trash', str(dbval)) - self.assertIsInstance(dbval, str) - self.assertIsInstance(val, TrackedValue) 
+ def test_ne_json_1(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['os'] != {}"): + p = get(p for p in self.Product if p.info['os'] != {}) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) + self.assertFalse(p) @db_session - def test_starred_path1(self): - if self.db_provider not in ['mysql', 'oracle']: - raise unittest.SkipTest('* in path is not supported by %s' % self.db_provider) - g = select(e.info[:][...] for e in self.E) - for val in g[:]: - self.assertListEqual(val, ['100x50']) + def test_ne_json_2(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['os'] != Json({})"): + p = get(p for p in self.Product if p.info['os'] != Json({})) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) + self.assertFalse(p) @db_session - def test_starred_gen_as_string(self): - if self.db_provider not in ['mysql', 'oracle']: - raise unittest.SkipTest('* in path is not supported by %s' % self.db_provider) - g = select('e.info[:][...] 
for e in self.E') - for val in g[:]: - self.assertListEqual(val, ['100x50']) + def test_equal_list_1(self): + colors = ['Gold', 'Silver', 'Space Gray'] + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"): + p = get(p for p in self.Product if p.info['colors'] == Json(colors)) + self.assertTrue(p) @db_session - def test_starred_path2(self): - if self.db_provider not in ['mysql', 'oracle']: - raise unittest.SkipTest('* in path is not supported by %s' % self.db_provider) - g = select(e.extra_info[...][0] for e in self.E) - for val in g[:]: - self.assertListEqual(val, ['warranty 1 year']) - - ##### 'key' in json + @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == ['Gold']") + def test_equal_list_2(self): + p = get(p for p in self.Product if p.info['colors'] == ['Gold']) @db_session - def test_in_dict(self): - obj = select( - m.info[2]['size'] for m in self.M if 'size' in m.info[2] - ).first() - self.assertTrue(obj) + def test_equal_list_3(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"): + p = get(p for p in self.Product if p.info['colors'] != Json(['Gold'])) + self.assertIsNotNone(p) @db_session - def test_not_in_dict(self): - obj = select( - m.info for m in self.M if 'size' not in m.info[2] - ).first() - self.assertEqual(obj, None) - obj = select( - m.info for m in self.M if 'siz' not in m.info[2] - ).first() - self.assertTrue(obj) + def test_equal_list_4(self): + colors = ['Gold', 'Silver', 'Space Gray'] + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"): + p = get(p for p in self.Product if p.info['colors'] == Json(colors)) + self.assertTrue(p) 
@db_session - def test_in_list(self): - obj = select( - m.info[3] for m in self.M if 'item1' in m.info[3] - ).first() - self.assertTrue(obj) - obj = select( - m.info for m in self.M if 'description' in m.info - ).first() - self.assertTrue(obj) + @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == []") + def test_equal_empty_list_1(self): + p = get(p for p in self.Product if p.info['colors'] == []) @db_session - def test_not_in_list(self): - obj = select( - m.info[3] for m in self.M if 'item1' not in m.info[3] - ).first() - self.assertEqual(obj, None) - obj = select( - m.info[3] for m in self.M if 'ite' not in m.info[3] - ).first() - self.assertIn('item1', obj) + def test_equal_empty_list_2(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] == Json([])"): + p = get(p for p in self.Product if p.info['colors'] == Json([])) + self.assertIsNone(p) @db_session - def test_var_in_json(self): - if self.db_provider in ('mysql', 'oracle'): - if_implemented = lambda: self.assertRaises(NotImplementedError) - else: - @contextmanager - def if_implemented(): - yield - with if_implemented(): - key = 'item1' - obj = select( - m.info[3] for m in self.M if key in m.info[3] - ).first() - self.assertTrue(obj) + def test_ne_list(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"): + p = get(p for p in self.Product if p.info['colors'] != Json(['Gold'])) + self.assertTrue(p) @db_session - def test_get_json_attr(self): - ''' query should not contain distinct - ''' - if self.db_provider != 'oracle': - raise unittest.SkipTest - obj = get( - m.info for m in self.M - ) - self.assertTrue(obj) + def test_ne_empty_list(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + "Oracle 
does not support comparison of json structures: p.info['colors'] != Json([])"): + p = get(p for p in self.Product if p.info['colors'] != Json([])) + self.assertTrue(p) @db_session - def test_select_first(self): - ''' query shoud not contain ORDER BY - ''' - if self.db_provider != 'oracle': - raise unittest.SkipTest - obj = select( - m.info for m in self.M - ).first() - self.assertTrue(obj) + def test_dbval2val(self): + p = select(p for p in self.Product)[:][0] + attr = self.Product.info + val = p._vals_[attr] + dbval = p._dbvals_[attr] + self.assertIsInstance(dbval, basestring) + self.assertIsInstance(val, TrackedValue) + p.info['os']['version'] = '9' + self.assertIs(val, p._vals_[attr]) + self.assertIs(dbval, p._dbvals_[attr]) + p.flush() + self.assertIs(val, p._vals_[attr]) + self.assertNotEqual(dbval, p._dbvals_[attr]) - def test_in_json_regexp(self): - if self.db_provider != 'oracle': - raise unittest.SkipTest - import re - from pony.orm.dbproviders.oracle import search_in_json_list_regexp - regexp = search_in_json_list_regexp('item') - pos = [ - '["item"]', - '[0, "item"]', - '[{}, "item", []]', - '[{"a": 1}, "item", []]', - '[false, "item", "erg"]', - ] - for s in pos: - self.assertTrue(re.search(regexp, s)) - neg = [ - '[["item"]]', - '[{"item": 0]]', - '["1 item", "item 1"]', - '[0, " "]', - '[]', - ] - for s in neg: - self.assertFalse(re.search(regexp, s)) - - -class TestDataTypes(unittest.TestCase): + @db_session + def test_wildcard_path_1(self): + with raises_if(self, self.db.provider.dialect not in ('Oracle', 'MySQL'), + TranslationError, '...does not support wildcards in JSON path...'): + names = get(p.info['models'][:]['name'] for p in self.Product) + self.assertSetEqual(set(names), {'Wi-Fi', 'Wi-Fi + Cellular'}) - in_db_session = False + @db_session + def test_wildcard_path_2(self): + with raises_if(self, self.db.provider.dialect not in ('Oracle', 'MySQL'), + TranslationError, '...does not support wildcards in JSON path...'): + values = 
get(p.info['os'][...] for p in self.Product) + self.assertSetEqual(set(values), {'iOS', '8'}) - from ponytest import pony_fixtures - pony_fixtures = list(pony_fixtures) + [json1_cli] + @db_session + def test_wildcard_path_3(self): + with raises_if(self, self.db.provider.dialect not in ('Oracle', 'MySQL'), + TranslationError, '...does not support wildcards in JSON path...'): + names = get(p.info[...][0]['name'] for p in self.Product) + self.assertSetEqual(set(names), {'Wi-Fi'}) - @classmethod - def make_entities(cls): - class Data(cls.db.Entity): - data = Optional(Json) + @db_session + def test_wildcard_path_4(self): + if self.db.provider.dialect == 'Oracle': + raise unittest.SkipTest + with raises_if(self, self.db.provider.dialect != 'MySQL', + TranslationError, '...does not support wildcards in JSON path...'): + values = get(p.info[...][:][...][:] for p in self.Product)[:] + self.assertSetEqual(set(values), {'16GB', '64GB'}) @db_session - def setUp(self): - self.db.execute('delete from %s' % self.db.Data._table_) + def test_wildcard_path_with_params(self): + if self.db.provider.dialect != 'Oracle': + exc_msg = '...does not support wildcards in JSON path...' + else: + exc_msg = "Oracle doesn't allow parameters in JSON paths" + with raises_if(self, self.db.provider.dialect != 'MySQL', TranslationError, exc_msg): + key = 'models' + index = 0 + values = get(p.info[key][:]['capacity'][index] for p in self.Product) + self.assertListEqual(values, ['16GB', '16GB']) + @db_session + def test_wildcard_path_with_params_as_string(self): + if self.db.provider.dialect != 'Oracle': + exc_msg = '...does not support wildcards in JSON path...' 
+ else: + exc_msg = "Oracle doesn't allow parameters in JSON paths" + with raises_if(self, self.db.provider.dialect != 'MySQL', TranslationError, exc_msg): + key = 'models' + index = 0 + values = get("p.info[key][:]['capacity'][index] for p in self.Product") + self.assertListEqual(values, ['16GB', '16GB']) - def test_int(self): + @db_session + def test_wildcard_path_in_condition(self): + errors = { + 'MySQL': 'Wildcards are not allowed in json_contains()', + 'SQLite': '...does not support wildcards in JSON path...', + 'PostgreSQL': '...does not support wildcards in JSON path...' + } + dialect = self.db.provider.dialect + with raises_if(self, dialect in errors, TranslationError, errors.get(dialect)): + p = get(p for p in self.Product if '16GB' in p.info['models'][:]['capacity']) + self.assertTrue(p) - db = self.db - with db_session: - db.Data(data={'val': 1}) + ##### 'key' in json - with db_session: - obj = get(d for d in db.Data if d.data['val'] == 1) - self.assertEqual(obj.data['val'], 1) + @db_session + def test_in_dict(self): + obj = get(p for p in self.Product if 'resolution' in p.info['display']) + self.assertTrue(obj) - def test_compare_int(self): - db = self.db - with db_session: - db.Data(data={'val': 3}) + @db_session + def test_not_in_dict(self): + obj = get(p for p in self.Product if 'resolution' not in p.info['display']) + self.assertIs(obj, None) + obj = get(p for p in self.Product if 'xyz' not in p.info['display']) + self.assertTrue(obj) - with db_session: - self.assertTrue( - get(d for d in db.Data if d.data['val'] > 2) - ) - self.assertTrue( - get(d for d in db.Data if d.data['val'] < 4) - ) + @db_session + def test_in_list(self): + obj = get(p for p in self.Product if 'Gold' in p.info['colors']) + self.assertTrue(obj) - def test_str(self): - db = self.db - with db_session: - db.Data(data={'val': "1"}) + @db_session + def test_not_in_list(self): + obj = get(p for p in self.Product if 'White' not in p.info['colors']) + self.assertTrue(obj) + obj = 
get(p for p in self.Product if 'Gold' not in p.info['colors']) + self.assertIs(obj, None) - with db_session: - obj = get(d for d in db.Data if d.data['val'] == '1') + @db_session + def test_var_in_json(self): + with raises_if(self, self.db.provider.dialect == 'Oracle', + TypeError, "For `key in JSON` operation Oracle supports literal key values only, " + "parameters are not allowed: key in p.info['colors']"): + key = 'Gold' + obj = get(p for p in self.Product if key in p.info['colors']) self.assertTrue(obj) - def test_none(self): - db = self.db - with db_session: - db.Data() - - with db_session: - data = get(d for d in db.Data if d.data is None) - self.assertTrue(data) - - # def test_is_null(self): - # db = self.db - # with db_session: - # db.Data(data={'val': None}) - - # with db_session: - # data = get(d for d in db.Data if d.data['val'] is None) - # self.assertTrue(data) - - # def test_eq_null(self): - # db = self.db - # with db_session: - # db.Data(data={'val': None}) - - # with db_session: - # data = get(d for d in db.Data if d.data['val'] == None) - # self.assertTrue(data) - - def test_bool(self): - with db_session: - self.db.Data(data={'val': True, 'id': 1}) - self.db.Data(data={'val': False, 'id': 2}) - - with db_session: - val = get( - d.data['id'] for d in self.db.Data - if d.data['val'] == False - ) - self.assertEqual(val, 2) - val = get( - d.data['id'] for d in self.db.Data - if d.data['val'] == True - ) - self.assertEqual(val, 1) - - def test_nonzero(self): - with db_session: - self.db.Data(data={'val': True, 'id': 1}) - self.db.Data(data={'val': False, 'id': 2}) - self.db.Data(data={'val': 0, 'id': 3}) - self.db.Data(data={'val': '', 'id': 4}) - self.db.Data(data={'id': 5}) - - if self.db_provider == 'oracle': - assert_raises = lambda: self.assertRaises(NotImplementedError) - else: - @contextmanager - def assert_raises(): - yield - - with db_session, assert_raises(): - val = get( - d.data['id'] for d in self.db.Data - if d.data['val'] - ) - 
self.assertEqual(val, 1) - - - def test_float(self): - with db_session: - self.db.Data(data={'val': 3.14}) - - with db_session: - val = get(d.data['val'] for d in self.db.Data) - self.assertIsInstance(val, float) + @db_session + def test_select_first(self): + # query should not contain ORDER BY + obj = select(p.info for p in self.Product).first() + self.assertNotIn('order by', self.db.last_sql.lower()) - def test_compare_float(self): + def test_sql_inject(self): + # test quote in json is not causing error with db_session: - self.db.Data(data={'val': 3.14}) + p = select(p for p in self.Product).first() + p.info['display']['size'] = "0' 9.7\"" with db_session: - val = get( - d.data['val'] for d in self.db.Data - if d.data['val'] < 3.15 - ) - self.assertIsInstance(val, float) - - - -# from ._postgres import JsonConcatTest #, JsonContainsTest # TODO - - -class TestSqliteFallback(unittest.TestCase): - - from ponytest import pony_fixtures - pony_fixtures = list(pony_fixtures) + [ - [no_json1_fixture] - ] - - @classmethod - def make_entities(cls): - class Person(cls.db.Entity): - name = Required(str) - data = Optional(Json) + p = select(p for p in self.Product).first() + self.assertEqual(p.info['display']['size'], "0' 9.7\"") + @db_session + def test_int_compare(self): + p = get(p for p in self.Product if p.info['display']['resolution'][0] == 2048) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['display']['resolution'][0] != 2048) + self.assertIsNone(p) + p = get(p for p in self.Product if p.info['display']['resolution'][0] < 2048) + self.assertIs(p, None) + p = get(p for p in self.Product if p.info['display']['resolution'][0] <= 2048) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['display']['resolution'][0] > 2048) + self.assertIs(p, None) + p = get(p for p in self.Product if p.info['display']['resolution'][0] >= 2048) + self.assertTrue(p) - def setUp(self): - self.db.execute('delete from %s' % self.db.Person._table_) + @db_session + 
def test_float_compare(self): + p = get(p for p in self.Product if p.info['display']['size'] > 9.5) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['display']['size'] < 9.8) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['display']['size'] < 9.5) + self.assertIsNone(p) + p = get(p for p in self.Product if p.info['display']['size'] > 9.8) + self.assertIsNone(p) + @db_session + def test_str_compare(self): + p = get(p for p in self.Product if p.info['ram'] == '8GB') + self.assertTrue(p) + p = get(p for p in self.Product if p.info['ram'] != '8GB') + self.assertIsNone(p) + p = get(p for p in self.Product if p.info['ram'] < '9GB') + self.assertTrue(p) + p = get(p for p in self.Product if p.info['ram'] > '7GB') + self.assertTrue(p) + p = get(p for p in self.Product if p.info['ram'] > '9GB') + self.assertIsNone(p) + p = get(p for p in self.Product if p.info['ram'] < '7GB') + self.assertIsNone(p) - def test(self): - Person = self.db.Person - with db_session: - Person(name='John') - Person(name='Mike', data=dict(a=1,b=2)) - with db_session: - p = Person[1] - p.data = dict(c=[2, 3, 4], d='d') - p = Person[2] - p.data['c'] = [1, 2, 3] + @db_session + def test_bool_compare(self): + p = get(p for p in self.Product if p.info['display']['multi-touch'] == True) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['display']['multi-touch'] is True) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['display']['multi-touch'] == False) + self.assertIsNone(p) + p = get(p for p in self.Product if p.info['display']['multi-touch'] is False) + self.assertIsNone(p) + p = get(p for p in self.Product if p.info['discontinued'] == False) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['discontinued'] == True) + self.assertIsNone(p) - qs = select(p for p in Person if p.data['c'][1] == 2) - self.assertEqual(qs.count(), 1) + @db_session + def test_none_compare(self): + p = get(p for p in self.Product if 
p.info['videoUrl'] is None) + self.assertTrue(p) + p = get(p for p in self.Product if p.info['videoUrl'] is not None) + self.assertIsNone(p) + @db_session + def test_none_for_nonexistent_path(self): + p = get(p for p in self.Product if p.info['some_attr'] is None) + self.assertTrue(p) - def test_cmp(self): - Person = self.db.Person - with db_session: - Person(name='Mike', data=[4]) - with db_session: - qs = select(p for p in Person if p.data[0] < 5) - self.assertEqual(qs.count(), 1) - qs = select(p for p in Person if p.data[0] > 3) - self.assertEqual(qs.count(), 1) \ No newline at end of file + def test_nonzero(self): + Product = self.Product + with db_session: + delete(p for p in Product) + Product(name='P1', info=dict(id=1, val=True)) + Product(name='P2', info=dict(id=2, val=False)) + Product(name='P3', info=dict(id=3, val=0)) + Product(name='P4', info=dict(id=4, val=1)) + Product(name='P5', info=dict(id=5, val='')) + Product(name='P6', info=dict(id=6, val='x')) + Product(name='P7', info=dict(id=7, val=[])) + Product(name='P8', info=dict(id=8, val=[1, 2, 3])) + Product(name='P9', info=dict(id=9, val={})) + Product(name='P10', info=dict(id=10, val={'a': 'b'})) + Product(name='P11', info=dict(id=11)) + Product(name='P12', info=dict(id=12, val='True')) + Product(name='P13', info=dict(id=13, val='False')) + Product(name='P14', info=dict(id=14, val='0')) + Product(name='P15', info=dict(id=15, val='1')) + Product(name='P16', info=dict(id=16, val='""')) + Product(name='P17', info=dict(id=17, val='[]')) + Product(name='P18', info=dict(id=18, val='{}')) + + with db_session: + val = select(p.info['id'] for p in Product if not p.info['val']) + self.assertEqual(tuple(sorted(val)), (2, 3, 5, 7, 9, 11)) diff --git a/pony/orm/tests/testutils.py b/pony/orm/tests/testutils.py index 87f9a91de..62cefbdc9 100644 --- a/pony/orm/tests/testutils.py +++ b/pony/orm/tests/testutils.py @@ -1,6 +1,9 @@ from __future__ import absolute_import, print_function, division from pony.py23compat 
import basestring +from functools import wraps +from contextlib import contextmanager + from pony.orm.core import Database from pony.utils import import_module @@ -18,6 +21,25 @@ def wrapper(self, *args, **kwargs): return wrapper return decorator +@contextmanager +def raises_if(test, cond, exc_class, exc_msg=None): + try: + yield + except exc_class as e: + test.assertTrue(cond) + if exc_msg is None: pass + elif exc_msg.startswith('...') and exc_msg != '...': + if exc_msg.endswith('...'): + test.assertIn(exc_msg[3:-3], str(e)) + else: + test.assertTrue(str(e).endswith(exc_msg[3:])) + elif exc_msg.endswith('...'): + test.assertTrue(str(e).startswith(exc_msg[:-3])) + else: + test.assertEqual(str(e), exc_msg) + else: + test.assertFalse(cond) + def flatten(x): result = [] for el in x: From 5105e1683e5a1cdb92dd58539a2910db905d69a8 Mon Sep 17 00:00:00 2001 From: Vitalii Date: Thu, 18 Aug 2016 14:45:28 +0300 Subject: [PATCH 025/547] json tests: update fixtures --- pony/fixtures.py | 340 ------------------- pony/orm/tests/fixtures.py | 548 +++++++++++++++++++++++++++++++ pony/orm/tests/test_json/test.py | 44 +-- 3 files changed, 552 insertions(+), 380 deletions(-) delete mode 100644 pony/fixtures.py create mode 100644 pony/orm/tests/fixtures.py diff --git a/pony/fixtures.py b/pony/fixtures.py deleted file mode 100644 index f61a8798d..000000000 --- a/pony/fixtures.py +++ /dev/null @@ -1,340 +0,0 @@ -import os -import logging - -from pony.py23compat import PY2 -from ponytest import with_cli_args, pony_fixtures - -from functools import wraps -import click -from contextlib import contextmanager, closing - -from pony.utils import cached_property, class_cached_property - -from pony.orm.dbproviders.mysql import mysql_module -from pony.utils import cached_property, class_property - -if not PY2: - from contextlib import contextmanager -else: - from contextlib2 import contextmanager - -from pony.orm import db_session, Database, rollback - - -class DBContext(object): - - class_scoped 
= True - - def __init__(self, test_cls): - test_cls.db_fixture = self - test_cls.db = class_property(lambda cls: self.db) - test_cls.db_provider = class_property(lambda cls: self.provider) - self.test_cls = test_cls - - @class_property - def fixture_name(cls): - return cls.provider - - def init_db(self): - raise NotImplementedError - - @cached_property - def db(self): - raise NotImplementedError - - def __enter__(self): - self.init_db() - self.test_cls.make_entities() - self.db.generate_mapping(check_tables=True, create_tables=True) - - def __exit__(self, *exc_info): - self.db.provider.disconnect() - - - @classmethod - @with_cli_args - @click.option('--db', '-d', 'database', multiple=True) - @click.option('--exclude-db', '-e', multiple=True) - def invoke(cls, database, exclude_db): - fixture = [ - MySqlContext, OracleContext, SqliteContext, PostgresContext, - SqlServerContext, - ] - all_db = [ctx.provider for ctx in fixture] - for db in database: - if db == 'all': - continue - assert db in all_db, ( - "Unknown provider: %s. Use one of %s." 
% (db, ', '.join(all_db)) - ) - if 'all' in database: - database = all_db - elif exclude_db: - database = set(all_db) - set(exclude_db) - elif not database: - database = ['sqlite'] - for Ctx in fixture: - if Ctx.provider in database: - yield Ctx - - db_name = 'testdb' - - -pony_fixtures.appendleft(DBContext.invoke) - - -class MySqlContext(DBContext): - provider = 'mysql' - - - def drop_db(self, cursor): - cursor.execute('use sys') - cursor.execute('drop database %s' % self.db_name) - - - def init_db(self): - with closing(mysql_module.connect(**self.CONN).cursor()) as c: - try: - self.drop_db(c) - except mysql_module.DatabaseError as exc: - print('Failed to drop db: %s' % exc) - c.execute('create database %s' % self.db_name) - c.execute('use %s' % self.db_name) - - CONN = { - 'host': "localhost", - 'user': "ponytest", - 'passwd': "ponytest", - } - - @cached_property - def db(self): - CONN = dict(self.CONN, db=self.db_name) - return Database('mysql', **CONN) - - -class SqlServerContext(DBContext): - - provider = 'sqlserver' - - def get_conn_string(self, db=None): - s = ( - 'DSN=MSSQLdb;' - 'SERVER=mssql;' - 'UID=sa;' - 'PWD=pass;' - ) - if db: - s += 'DATABASE=%s' % db - return s - - @cached_property - def db(self): - CONN = self.get_conn_string(self.db_name) - return Database('mssqlserver', CONN) - - def init_db(self): - import pyodbc - cursor = pyodbc.connect(self.get_conn_string(), autocommit=True).cursor() - with closing(cursor) as c: - try: - self.drop_db(c) - except pyodbc.DatabaseError as exc: - print('Failed to drop db: %s' % exc) - c.execute('create database %s' % self.db_name) - c.execute('use %s' % self.db_name) - - def drop_db(self, cursor): - cursor.execute('use master') - cursor.execute('drop database %s' % self.db_name) - - -class SqliteContext(DBContext): - provider = 'sqlite' - - def init_db(self): - try: - os.remove(self.db_path) - except OSError as exc: - print('Failed to drop db: %s' % exc) - - - @cached_property - def db_path(self): - p = 
os.path.dirname(__file__) - p = os.path.join(p, self.db_name) - return os.path.abspath(p) - - @cached_property - def db(self): - return Database('sqlite', self.db_path, create_db=True) - - -class PostgresContext(DBContext): - provider = 'postgresql' - - def get_conn_dict(self, no_db=False): - d = dict( - user='ponytest', password='ponytest', - host='localhost' - ) - if not no_db: - d.update(database=self.db_name) - return d - - def init_db(self): - import psycopg2 - conn = psycopg2.connect( - **self.get_conn_dict(no_db=True) - ) - conn.set_isolation_level(0) - with closing(conn.cursor()) as cursor: - try: - self.drop_db(cursor) - except psycopg2.DatabaseError as exc: - print('Failed to drop db: %s' % exc) - cursor.execute('create database %s' % self.db_name) - - def drop_db(self, cursor): - cursor.execute('drop database %s' % self.db_name) - - - @cached_property - def db(self): - return Database('postgres', **self.get_conn_dict()) - - -class OracleContext(DBContext): - provider = 'oracle' - - def __enter__(self): - os.environ.update(dict( - ORACLE_BASE='/u01/app/oracle', - ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1', - ORACLE_OWNR='oracle', - ORACLE_SID='orcl', - )) - return super(OracleContext, self).__enter__() - - def init_db(self): - import cx_Oracle - with closing(self.connect_sys()) as conn: - with closing(conn.cursor()) as cursor: - try: - self._destroy_test_user(cursor) - self._drop_tablespace(cursor) - except cx_Oracle.DatabaseError as exc: - print('Failed to drop db: %s' % exc) - cursor.execute( - """CREATE TABLESPACE %(tblspace)s - DATAFILE '%(datafile)s' SIZE 20M - REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s - """ % self.parameters) - cursor.execute( - """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s - TEMPFILE '%(datafile_tmp)s' SIZE 20M - REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s - """ % self.parameters) - self._create_test_user(cursor) - - - def _drop_tablespace(self, cursor): - cursor.execute( - 'DROP TABLESPACE 
%(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' - % self.parameters) - cursor.execute( - 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' - % self.parameters) - - - parameters = { - 'tblspace': 'test_tblspace', - 'tblspace_temp': 'test_tblspace_temp', - 'datafile': 'test_datafile.dbf', - 'datafile_tmp': 'test_datafile_tmp.dbf', - 'user': 'ponytest', - 'password': 'ponytest', - 'maxsize': '100M', - 'maxsize_tmp': '100M', - } - - def connect_sys(self): - import cx_Oracle - return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA) - - def connect_test(self): - import cx_Oracle - return cx_Oracle.connect('test_user/test_password@localhost/ORCL') - - - @cached_property - def db(self): - return Database('oracle', 'test_user/test_password@localhost/ORCL') - - def _create_test_user(self, cursor): - cursor.execute( - """CREATE USER %(user)s - IDENTIFIED BY %(password)s - DEFAULT TABLESPACE %(tblspace)s - TEMPORARY TABLESPACE %(tblspace_temp)s - QUOTA UNLIMITED ON %(tblspace)s - """ % self.parameters - ) - cursor.execute( - """GRANT CREATE SESSION, - CREATE TABLE, - CREATE SEQUENCE, - CREATE PROCEDURE, - CREATE TRIGGER - TO %(user)s - """ % self.parameters - ) - - def _destroy_test_user(self, cursor): - cursor.execute(''' - DROP USER %(user)s CASCADE - ''' % self.parameters) - - -@contextmanager -def logging_context(test): - level = logging.getLogger().level - from pony.orm.core import debug, sql_debug - logging.getLogger().setLevel(logging.INFO) - sql_debug(True) - yield - logging.getLogger().setLevel(level) - sql_debug(debug) - - -@with_cli_args -@click.option('--log', is_flag=True) -def use_logging(log): - if log: - yield logging_context - -pony_fixtures.appendleft(use_logging) - - -class DBSession(object): - - def __init__(self, test): - self.test = test - - @property - def in_db_session(self): - ret = getattr(self.test, 'in_db_session', True) - method = getattr(self.test, 
self.test._testMethodName) - return getattr(method, 'in_db_session', ret) - - def __enter__(self): - rollback() - if self.in_db_session: - db_session.__enter__() - - def __exit__(self, *exc_info): - rollback() - if self.in_db_session: - db_session.__exit__() - -pony_fixtures.appendleft([DBSession]) diff --git a/pony/orm/tests/fixtures.py b/pony/orm/tests/fixtures.py new file mode 100644 index 000000000..b395e1e06 --- /dev/null +++ b/pony/orm/tests/fixtures.py @@ -0,0 +1,548 @@ +import sys +import os +import logging + +from pony.py23compat import PY2 +from ponytest import with_cli_args, pony_fixtures, provider_validators, provider + +from functools import wraps, partial +import click +from contextlib import contextmanager, closing + + +from pony.orm.dbproviders.mysql import mysql_module +from pony.utils import cached_property, class_property + +if not PY2: + from contextlib import contextmanager, ContextDecorator, ExitStack +else: + from contextlib2 import contextmanager, ContextDecorator, ExitStack + +import unittest + +from pony.orm import db_session, Database, rollback, delete + +if not PY2: + from io import StringIO +else: + from StringIO import StringIO + +from multiprocessing import Process + +import threading + + +class DBContext(ContextDecorator): + + fixture = 'db' + enabled = False + + def __init__(self, Test): + if not isinstance(Test, type): + # FIXME ? 
+ TestCls = type(Test) + NewClass = type(TestCls.__name__, (TestCls,), {}) + NewClass.__module__ = TestCls.__module__ + NewClass.db = property(lambda t: self.db) + Test.__class__ = NewClass + else: + Test.db = class_property(lambda cls: self.db) + self.Test = Test + + @class_property + def fixture_name(cls): + return cls.db_provider + + @class_property + def db_provider(cls): + # is used in tests + return cls.provider_key + + def init_db(self): + raise NotImplementedError + + @cached_property + def db(self): + raise NotImplementedError + + def __enter__(self): + self.init_db() + try: + self.Test.make_entities() + except (AttributeError, TypeError): + # No method make_entities with due signature + pass + else: + self.db.generate_mapping(check_tables=True, create_tables=True) + return self.db + + def __exit__(self, *exc_info): + self.db.provider.disconnect() + + @classmethod + def validate_fixtures(cls, fixtures, config): + return any(f.fixture_key == 'db' for f in fixtures) + + db_name = 'testdb' + + +@provider() +class GenerateMapping(ContextDecorator): + + weight = 200 + fixture = 'generate_mapping' + + def __init__(self, Test): + self.Test = Test + + def __enter__(self): + db = getattr(self.Test, 'db', None) + if not db or not db.entities: + return + for entity in db.entities.values(): + if entity._database_.schema is None: + db.generate_mapping(check_tables=True, create_tables=True) + break + + def __exit__(self, *exc_info): + pass + +@provider() +class MySqlContext(DBContext): + provider_key = 'mysql' + + def drop_db(self, cursor): + cursor.execute('use sys') + cursor.execute('drop database %s' % self.db_name) + + + def init_db(self): + with closing(mysql_module.connect(**self.CONN).cursor()) as c: + try: + self.drop_db(c) + except mysql_module.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + c.execute('create database %s' % self.db_name) + c.execute('use %s' % self.db_name) + + CONN = { + 'host': "localhost", + 'user': "ponytest", + 'passwd': 
"ponytest", + } + + @cached_property + def db(self): + CONN = dict(self.CONN, db=self.db_name) + return Database('mysql', **CONN) + +@provider() +class SqlServerContext(DBContext): + + provider_key = 'sqlserver' + + def get_conn_string(self, db=None): + s = ( + 'DSN=MSSQLdb;' + 'SERVER=mssql;' + 'UID=sa;' + 'PWD=pass;' + ) + if db: + s += 'DATABASE=%s' % db + return s + + @cached_property + def db(self): + CONN = self.get_conn_string(self.db_name) + return Database('mssqlserver', CONN) + + def init_db(self): + import pyodbc + cursor = pyodbc.connect(self.get_conn_string(), autocommit=True).cursor() + with closing(cursor) as c: + try: + self.drop_db(c) + except pyodbc.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + c.execute('create database %s' % self.db_name) + c.execute('use %s' % self.db_name) + + def drop_db(self, cursor): + cursor.execute('use master') + cursor.execute('drop database %s' % self.db_name) + + +@provider() +class SqliteContext(DBContext): + provider_key = 'sqlite' + enabled = True + + def init_db(self): + try: + os.remove(self.db_path) + except OSError as exc: + print('Failed to drop db: %s' % exc) + + fixture_name = 'sqlite, with json1' + + + # TODO if json1 is not installed, do not run the tests + + @cached_property + def db_path(self): + p = os.path.dirname(__file__) + p = os.path.join(p, '%s.sqlite' % self.db_name) + return os.path.abspath(p) + + @cached_property + def db(self): + return Database('sqlite', self.db_path, create_db=True) + + +@provider() +class PostgresContext(DBContext): + provider_key = 'postgresql' + + def get_conn_dict(self, no_db=False): + d = dict( + user='ponytest', password='ponytest', + host='localhost', database='postgres', + ) + if not no_db: + d.update(database=self.db_name) + return d + + def init_db(self): + import psycopg2 + conn = psycopg2.connect( + **self.get_conn_dict(no_db=True) + ) + conn.set_isolation_level(0) + with closing(conn.cursor()) as cursor: + try: + self.drop_db(cursor) + except 
psycopg2.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + cursor.execute('create database %s' % self.db_name) + + def drop_db(self, cursor): + cursor.execute('drop database %s' % self.db_name) + + + @cached_property + def db(self): + return Database('postgres', **self.get_conn_dict()) + + +@provider() +class OracleContext(DBContext): + provider_key = 'oracle' + + def __enter__(self): + os.environ.update(dict( + ORACLE_BASE='/u01/app/oracle', + ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1', + ORACLE_OWNR='oracle', + ORACLE_SID='orcl', + )) + return super(OracleContext, self).__enter__() + + def init_db(self): + + import cx_Oracle + with closing(self.connect_sys()) as conn: + with closing(conn.cursor()) as cursor: + try: + self._destroy_test_user(cursor) + except cx_Oracle.DatabaseError as exc: + print('Failed to drop user: %s' % exc) + try: + self._drop_tablespace(cursor) + except cx_Oracle.DatabaseError as exc: + print('Failed to drop db: %s' % exc) + cursor.execute( + """CREATE TABLESPACE %(tblspace)s + DATAFILE '%(datafile)s' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s + """ % self.parameters) + cursor.execute( + """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s + TEMPFILE '%(datafile_tmp)s' SIZE 20M + REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s + """ % self.parameters) + self._create_test_user(cursor) + + + def _drop_tablespace(self, cursor): + cursor.execute( + 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' + % self.parameters) + cursor.execute( + 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS' + % self.parameters) + + + parameters = { + 'tblspace': 'test_tblspace', + 'tblspace_temp': 'test_tblspace_temp', + 'datafile': 'test_datafile.dbf', + 'datafile_tmp': 'test_datafile_tmp.dbf', + 'user': 'ponytest', + 'password': 'ponytest', + 'maxsize': '100M', + 'maxsize_tmp': '100M', + } + + def connect_sys(self): + import cx_Oracle + return 
cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA) + + def connect_test(self): + import cx_Oracle + return cx_Oracle.connect('ponytest/ponytest@localhost/ORCL') + + + @cached_property + def db(self): + return Database('oracle', 'ponytest/ponytest@localhost/ORCL') + + def _create_test_user(self, cursor): + cursor.execute( + """CREATE USER %(user)s + IDENTIFIED BY %(password)s + DEFAULT TABLESPACE %(tblspace)s + TEMPORARY TABLESPACE %(tblspace_temp)s + QUOTA UNLIMITED ON %(tblspace)s + """ % self.parameters + ) + cursor.execute( + """GRANT CREATE SESSION, + CREATE TABLE, + CREATE SEQUENCE, + CREATE PROCEDURE, + CREATE TRIGGER + TO %(user)s + """ % self.parameters + ) + + def _destroy_test_user(self, cursor): + cursor.execute(''' + DROP USER %(user)s CASCADE + ''' % self.parameters) + + +@provider(fixture='log', weight=100, enabled=False) +@contextmanager +def logging_context(test): + level = logging.getLogger().level + from pony.orm.core import debug, sql_debug + logging.getLogger().setLevel(logging.INFO) + sql_debug(True) + yield + logging.getLogger().setLevel(level) + sql_debug(debug) + +# @provider('log_all', scope='class', weight=-100, enabled=False) +# def log_all(Test): +# return logging_context(Test) + + + +# @with_cli_args +# @click.option('--log', 'scope', flag_value='test') +# @click.option('--log-all', 'scope', flag_value='all') +# def use_logging(scope): +# if scope == 'test': +# yield logging_context +# elif scope =='all': +# yield log_all + + + + +@provider() +class DBSessionProvider(object): + + fixture= 'db_session' + + weight = 30 + + def __new__(cls, test): + return db_session + + +@provider(fixture='rollback', weight=40) +@contextmanager +def do_rollback(test): + try: + yield + finally: + rollback() + + +@provider() +class SeparateProcess(object): + + # TODO read failures from sep process better + + fixture = 'separate_process' + + enabled = False + + scope = 'class' + + def __init__(self, Test): + self.Test = Test + + def 
__call__(self, func): + def wrapper(Test): + rnr = unittest.runner.TextTestRunner() + TestCls = Test if isinstance(Test, type) else type(Test) + def runTest(self): + try: + func(Test) + finally: + rnr.stream = unittest.runner._WritelnDecorator(StringIO()) + name = getattr(func, '__name__', 'runTest') + Case = type(TestCls.__name__, (TestCls,), {name: runTest}) + Case.__module__ = TestCls.__module__ + case = Case(name) + suite = unittest.suite.TestSuite([case]) + def run(): + result = rnr.run(suite) + if not result.wasSuccessful(): + sys.exit(1) + p = Process(target=run, args=()) + p.start() + p.join() + case.assertEqual(p.exitcode, 0) + return wrapper + + @classmethod + def validate_chain(cls, fixtures, klass): + for f in fixtures: + if f.KEY in ('ipdb', 'ipdb_all'): + return False + for f in fixtures: + if f.KEY == 'db' and f.provider_key in ('sqlserver', 'oracle'): + return True + +@provider() +class ClearTables(ContextDecorator): + + fixture = 'clear_tables' + + def __init__(self, test): + self.test = test + + def __enter__(self): + pass + + @db_session + def __exit__(self, *exc_info): + db = self.test.db + for entity in db.entities.values(): + if entity._database_.schema is None: + break + delete(i for i in entity) + + +@provider() +class NoJson1(SqliteContext): + provider_key = 'sqlite_no_json1' + fixture = 'db' + + def __init__(self, cls): + self.Test = cls + cls.no_json1 = True + return super(NoJson1, self).__init__(cls) + + fixture_name = 'sqlite, no json1' + + def __enter__(self): + resource = super(NoJson1, self).__enter__() + self.json1_available = self.Test.db.provider.json1_available + self.Test.db.provider.json1_available = False + return resource + + def __exit__(self, *exc_info): + self.Test.db.provider.json1_available = self.json1_available + return super(NoJson1, self).__exit__() + + + +import signal + +@provider() +class Timeout(object): + + fixture = 'timeout' + + @with_cli_args + @click.option('--timeout', type=int) + def __init__(self, Test, 
timeout): + self.Test = Test + self.timeout = timeout if timeout else Test.TIMEOUT + + scope = 'class' + enabled = False + + class Exception(Exception): + pass + + class FailedInSubprocess(Exception): + pass + + def __call__(self, func): + def wrapper(test): + p = Process(target=func, args=(test,)) + p.start() + + def on_expired(): + p.terminate() + + t = threading.Timer(self.timeout, on_expired) + t.start() + p.join() + t.cancel() + if p.exitcode == -signal.SIGTERM: + raise self.Exception + elif p.exitcode: + raise self.FailedInSubprocess + + return wrapper + + @classmethod + @with_cli_args + @click.option('--timeout', type=int) + def validate_chain(cls, fixtures, klass, timeout): + if not getattr(klass, 'TIMEOUT', None) and not timeout: + return False + for f in fixtures: + if f.KEY in ('ipdb', 'ipdb_all'): + return False + for f in fixtures: + if f.KEY == 'db' and f.provider_key in ('sqlserver', 'oracle'): + return True + + +pony_fixtures['test'].extend([ + 'log', + 'clear_tables', + 'db_session', +]) + +pony_fixtures['class'].extend([ + 'separate_process', + 'timeout', + 'db', + 'generate_mapping', +]) + +# def db_is_required(providers, config): +# return providers + +# provider_validators.update({ +# 'db': db_is_required, +# }) \ No newline at end of file diff --git a/pony/orm/tests/test_json/test.py b/pony/orm/tests/test_json/test.py index f6982f24e..f70b0e3b6 100644 --- a/pony/orm/tests/test_json/test.py +++ b/pony/orm/tests/test_json/test.py @@ -1,4 +1,4 @@ -# *uses fixtures* +from pony.py23compat import basestring import unittest @@ -10,46 +10,12 @@ from contextlib import contextmanager -import pony.fixtures -from ponytest import with_cli_args +import pony.orm.tests.fixtures +from ponytest import with_cli_args, TestCase -def no_json1_fixture(cls): - if cls.db_provider != 'sqlite': - raise unittest.SkipTest - cls.no_json1 = True - - @contextmanager - def mgr(): - json1_available = cls.db.provider.json1_available - cls.db.provider.json1_available = False - 
try: - yield - finally: - cls.db.provider.json1_available = json1_available - - return mgr() - -no_json1_fixture.class_scoped = True - - -@contextmanager -def empty_mgr(*args, **kw): - yield - - -@with_cli_args -@click.option('--json1', flag_value=True, default=None) -@click.option('--no-json1', 'json1', flag_value=False) -def json1_cli(json1): - if json1 is None or json1 is True: - yield empty_mgr - if json1 is None or json1 is False: - yield no_json1_fixture - - -class TestJson(unittest.TestCase): +class TestJson(TestCase): in_db_session = False @classmethod @@ -61,8 +27,6 @@ class Product(cls.db.Entity): cls.Product = cls.db.Product - from ponytest import pony_fixtures - pony_fixtures = list(pony_fixtures) + [json1_cli] @db_session def setUp(self): From 7086c8e471a7ebe2caca3f90bb0298c1e0cbf056 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 18 Aug 2016 17:10:09 +0300 Subject: [PATCH 026/547] Fix import error when MySQL driver is not installed --- pony/orm/tests/fixtures.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pony/orm/tests/fixtures.py b/pony/orm/tests/fixtures.py index b395e1e06..47c8921b5 100644 --- a/pony/orm/tests/fixtures.py +++ b/pony/orm/tests/fixtures.py @@ -9,8 +9,6 @@ import click from contextlib import contextmanager, closing - -from pony.orm.dbproviders.mysql import mysql_module from pony.utils import cached_property, class_property if not PY2: @@ -117,6 +115,7 @@ def drop_db(self, cursor): def init_db(self): + from pony.orm.dbproviders.mysql import mysql_module with closing(mysql_module.connect(**self.CONN).cursor()) as c: try: self.drop_db(c) From 1797e7115e212c16e2ec9e03336da85934c4b304 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 11 Jul 2016 20:57:39 +0300 Subject: [PATCH 027/547] Remove JSON_SUBTRACT_PATH & JSON_SUBTRACT_VALUE methods --- pony/orm/dbproviders/mysql.py | 2 -- pony/orm/dbproviders/postgres.py | 5 ----- pony/orm/dbproviders/sqlite.py | 4 ---- pony/orm/sqltranslation.py | 
18 ------------------ 4 files changed, 29 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 5dfb6c7cc..fdd501ff8 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -160,8 +160,6 @@ def DATETIME_SUB(builder, expr, delta): return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' def JSON_GETPATH(builder, expr, key): return 'json_extract(', builder(expr), ', ', builder(key), ')' - def JSON_SUBTRACT_PATH(builder, expr, key): - return 'json_remove(', builder(expr), ', ', builder(key), ')' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' def AS_JSON(builder, target): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 600484040..305335a74 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -177,11 +177,6 @@ def JSON_HAS_ANY(builder, array, value): raise NotImplementedError def JSON_HAS_ALL(builder, array, value): raise NotImplementedError - def JSON_SUBTRACT_VALUE(builder, expr, key): - val = builder.VALUE(key) - return '(', builder(expr), " - ", val, ')' - def JSON_SUBTRACT_PATH(builder, value, key): - return '(', builder(value), " #- ", builder(key), ')' def JSON_ARRAY_LENGTH(builder, value): return 'jsonb_array_length(', builder(value), ')' def _as_json(builder, target): diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 8257d7eb6..757d061a5 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -176,10 +176,6 @@ def JSON_GETPATH__QUOTE_STRINGS(builder, expr, key): return 'py_json_extract(', builder(expr), ', ', builder(key), ', 1)' ret = 'json_extract(', builder(expr), ', null, ', builder(key), ')' return 'unwrap_extract_json(', ret, ')' - def JSON_SUBTRACT_PATH(builder, expr, key): - if not builder.json1_available: - raise SqliteExtensionUnavailable('json1') - return 'json_remove(', builder(expr), ', ', builder(key), ')' def 
JSON_ARRAY_LENGTH(builder, value): if not builder.json1_available: raise SqliteExtensionUnavailable('json1') diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a6f244a4f..86453d4c0 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1566,24 +1566,6 @@ def _get_value(cls, monad): raise TypeError('Invalid JSON path item: %s' % ast2src(monad.node)) return monad.value - allow_subtract_key_syntax = False # support only subtracting path by default - - def __sub__(monad, other): - translator = monad.translator - left_sql, = monad.getsql() - items = None - if isinstance(other, translator.ListMonad): - items = other.items - elif not monad.allow_subtract_key_syntax: - items = [other] - else: - value = monad._get_value(other) - sql = ['JSON_SUBTRACT_VALUE', left_sql, value] - if items: - path = monad._get_path_sql(items) - sql = ['JSON_SUBTRACT_PATH', left_sql, path] - return translator.JsonExprMonad(translator, Json, sql) - def __getitem__(monad, item, is_overriden=False): ''' Transform the item and return it. Please override. 
From 0ba76cdf19b74b2f4e2ace1ba91c9721ed7ff256 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 11 Jul 2016 21:28:28 +0300 Subject: [PATCH 028/547] Remove allow_get_by_key_syntax flag and JSON_GET builder method --- pony/orm/dbproviders/postgres.py | 5 ----- pony/orm/sqltranslation.py | 6 ------ 2 files changed, 11 deletions(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 305335a74..5f6645396 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -40,8 +40,6 @@ class PGTranslator(SQLTranslator): dialect = 'PostgreSQL' class JsonItemMonad(sqltranslation.JsonItemMonad): - allow_get_by_key_syntax = True - def nonzero(monad): translator = monad.translator empty_str = translator.StringExprMonad( @@ -158,9 +156,6 @@ def JSON_PATH(builder, *items): ret.append(item) ret.append("}'") return ret - def JSON_GET(builder, expr, key): - val = builder.VALUE(key) - return '(', builder(expr), "->", val, ')' def JSON_GETPATH(builder, expr, key): return '(', builder(expr), "#>", builder(key), ')' def JSON_CONTAINS(builder, expr, path, key): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 86453d4c0..49ac1121a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1781,8 +1781,6 @@ class JsonExprMonad(JsonMixin, ExprMonad): class JsonItemMonad(JsonMixin, Monad): - allow_get_by_key_syntax = False - def __init__(monad, attr_monad, path): translator = attr_monad.translator monad.attr_monad = attr_monad @@ -1796,10 +1794,6 @@ def __getitem__(monad, key): def getsql(monad): base_sql, = monad.attr_monad.getsql() - if monad.allow_get_by_key_syntax and len(monad.path) == 1: - value = monad._get_value(monad.path[0]) - sql = ['JSON_GET', base_sql, value] - return [sql] path_sql = monad._get_path_sql(monad.path) sql = ['JSON_GETPATH'] sql.extend((base_sql, path_sql)) From ae01a4a17eebd6ca4c984b6871e389ebdf35ca53 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky 
Date: Mon, 11 Jul 2016 11:29:12 +0300 Subject: [PATCH 029/547] Add | concatenation operation which translates to || in PostgreSQL 9.5 --- pony/orm/dbproviders/postgres.py | 2 ++ pony/orm/sqltranslation.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 5f6645396..36a9487de 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -158,6 +158,8 @@ def JSON_PATH(builder, *items): return ret def JSON_GETPATH(builder, expr, key): return '(', builder(expr), "#>", builder(key), ')' + def JSON_CONCAT(builder, left, right): + return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): if path: json_sql = builder.JSON_GETPATH(expr, path) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 49ac1121a..a9b7afcdb 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -700,6 +700,16 @@ def postAnd(translator, node): return translator.AndMonad([ subnode.monad for subnode in node.nodes ]) def postOr(translator, node): return translator.OrMonad([ subnode.monad for subnode in node.nodes ]) + def postBitor(translator, node): + left, right = (subnode.monad for subnode in node.nodes) + return left | right + def postBitand(translator, node): + left, right = (subnode.monad for subnode in node.nodes) + return left & right + def postBitxor(translator, node): + left, right = (subnode.monad for subnode in node.nodes) + return left ^ right + def preNot(translator, node): translator.inside_not = not translator.inside_not def postNot(translator, node): @@ -1056,6 +1066,9 @@ def __truediv__(monad, monad2): throw(TypeError) def __floordiv__(monad, monad2): throw(TypeError) def __pow__(monad, monad2): throw(TypeError) def __neg__(monad): throw(TypeError) + def __or__(monad): throw(TypeError) + def __and__(monad): throw(TypeError) + def __xor__(monad): throw(TypeError) def abs(monad): 
throw(TypeError) class RawSQLMonad(Monad): @@ -1600,6 +1613,15 @@ def contains(monad, item, not_in=False): # else: # raise TypeError('Invalid JSON key: %s,' % ast2src(item.node)) + def __or__(monad, other): + translator = monad.translator + if not isinstance(other, translator.JsonMixin): + raise TypeError('Should be JSON: %s' % ast2src(other.node)) + left_sql, = monad.getsql() + right_sql, = other.getsql() + sql = ['JSON_CONCAT', left_sql, right_sql] + return translator.JsonConcatExprMonad(translator, Json, sql) + def len(monad): translator = monad.translator sql, = monad.getsql() From c0312113b766ea60e5c3f50b57eabb1c7534514d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 11 Jul 2016 11:30:22 +0300 Subject: [PATCH 030/547] Add wildcard path support to MySQL and Oracle --- pony/orm/dbproviders/oracle.py | 2 ++ pony/orm/ormtypes.py | 11 +++++++++-- pony/orm/sqlbuilding.py | 8 +++++++- pony/orm/sqltranslation.py | 11 +++++++++-- 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index b07a39e72..443f4626f 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -256,6 +256,8 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def JSON_GETPATH_STARRED(builder, expr, key): + return 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' def JSON_GETPATH(builder, expr, key): query = 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' return 'REGEXP_REPLACE(', query, ", '(^\[|\]$)', '')" diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 49d2b2783..9f07a9462 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -153,7 +153,7 @@ def normalize_type(t): if t is NoneType: return t t = type_normalization_dict.get(t, t) if t in primitive_types: return 
t - if issubclass(t, (slice, type(Ellipsis))): return t + if issubclass(t, (AnyItem, slice, type(Ellipsis))): return t if issubclass(t, basestring): return unicode if issubclass(t, (dict, Json)): return Json throw(TypeError, 'Unsupported type %r' % t.__name__) @@ -293,4 +293,11 @@ def __init__(self, wrapped): self.wrapped = wrapped def __repr__(self): - return 'Json %s' % repr(self.wrapped) + return '' % self.wrapped + +class AnyItem(object): + def __init__(self, type): + self.type = type + +AnyStr = AnyItem('Str') +AnyNum = AnyItem('Number') diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 599cee879..de2d68b63 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -8,7 +8,7 @@ from pony import options from pony.utils import datetime2timestamp, throw -from pony.orm.ormtypes import RawSQL +from pony.orm.ormtypes import RawSQL, Json, AnyNum, AnyStr class AstError(Exception): pass @@ -508,12 +508,18 @@ def JSON_PATH(builder, *items): for item in items: if isinstance(item, int): ret.append('[%d]' % item) + elif item is AnyNum: + ret.append('[*]') elif isinstance(item, str): ret.append('."%s"' % item) + elif item is AnyStr: + ret.append('.*') else: assert 0 ret.append('\'') return ret def JSON_GETPATH(builder, expr, key): raise NotImplementedError + def JSON_GETPATH_STARRED(builder, expr, key): + return builder.JSON_GETPATH(expr, key) def CAST(builder, expr, type): return 'CAST(', builder(expr), ' AS ', type, ')' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a9b7afcdb..46112393a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -17,7 +17,7 @@ from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ get_normalized_type_of, normalize_type, coerce_types, are_comparable_types, \ - Json + Json, AnyStr, AnyNum, AnyItem from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper 
@@ -1575,6 +1575,10 @@ def _get_path_sql(cls, items): @classmethod def _get_value(cls, monad): tr = monad.translator + if isinstance(monad, EllipsisMonad): + return AnyStr + if isinstance(monad, FullSliceMonad): + return AnyNum if not isinstance(monad, (tr.NumericConstMonad, tr.StringConstMonad)): raise TypeError('Invalid JSON path item: %s' % ast2src(monad.node)) return monad.value @@ -1817,7 +1821,10 @@ def __getitem__(monad, key): def getsql(monad): base_sql, = monad.attr_monad.getsql() path_sql = monad._get_path_sql(monad.path) - sql = ['JSON_GETPATH'] + if any(isinstance(item, AnyItem) for item in path_sql): + sql = ['JSON_GETPATH_STARRED'] + else: + sql = ['JSON_GETPATH'] sql.extend((base_sql, path_sql)) return [sql] From cac4988e01e8a62851aded602df6558335629eef Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 11 Jul 2016 21:09:10 +0300 Subject: [PATCH 031/547] JSON_PATH fixes --- pony/orm/dbproviders/postgres.py | 17 +++++++++-------- pony/orm/dbproviders/sqlite.py | 2 +- pony/orm/sqlbuilding.py | 18 +++++++++--------- 3 files changed, 19 insertions(+), 18 deletions(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 36a9487de..82d9a3062 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -26,6 +26,7 @@ from pony.orm.sqltranslation import SQLTranslator from pony.orm.sqlbuilding import Value, SQLBuilder from pony.converting import timedelta2str +from pony.utils import is_ident NoneType = type(None) @@ -148,14 +149,14 @@ def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' def JSON_PATH(builder, *items): - ret = ["'{"] - for i, item in enumerate(items): - if i: ret.append(', ') - if isinstance(item, basestring): - item = '"', item, '"' - ret.append(item) - ret.append("}'") - return ret + result = [] + for item in items: + if isinstance(item, int): + 
result.append(str(item)) + elif isinstance(item, basestring): + result.append(item if is_ident(item) else '"%s"' % item.replace('"', '\\"')) + else: assert False, item + return '{%s}' % ','.join(result) def JSON_GETPATH(builder, expr, key): return '(', builder(expr), "#>", builder(key), ')' def JSON_CONCAT(builder, left, right): diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 757d061a5..db2d702eb 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -166,7 +166,7 @@ def RANDOM(builder): def JSON_PATH(builder, *items): if builder.json1_available: return SQLBuilder.JSON_PATH(builder, *items) - return "'", json.dumps(items), "'" + return builder.VALUE(json.dumps(items)) def JSON_GETPATH(builder, expr, key): if not builder.json1_available: return 'py_json_extract(', builder(expr), ', ', builder(key), ', 0)' diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index de2d68b63..292459805 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -7,7 +7,7 @@ from binascii import hexlify from pony import options -from pony.utils import datetime2timestamp, throw +from pony.utils import datetime2timestamp, throw, is_ident from pony.orm.ormtypes import RawSQL, Json, AnyNum, AnyStr class AstError(Exception): pass @@ -504,19 +504,19 @@ def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] def JSON_PATH(builder, *items): - ret = ['\'$'] + result = ['\'$'] for item in items: if isinstance(item, int): - ret.append('[%d]' % item) + result.append('[%d]' % item) elif item is AnyNum: - ret.append('[*]') + result.append('[*]') elif isinstance(item, str): - ret.append('."%s"' % item) + result.append('.' 
+ item if is_ident(item) else '."%s"' % item.replace('"', '\\"')) elif item is AnyStr: - ret.append('.*') - else: assert 0 - ret.append('\'') - return ret + result.append('.*') + else: assert False + result.append('\'') + return result def JSON_GETPATH(builder, expr, key): raise NotImplementedError def JSON_GETPATH_STARRED(builder, expr, key): From 478a218af38f05097ad5f224f3a1480a3deb39fd Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 12 Jul 2016 19:43:03 +0300 Subject: [PATCH 032/547] Remove contains_json logic --- pony/orm/dbproviders/postgres.py | 2 -- pony/orm/sqltranslation.py | 15 --------------- 2 files changed, 17 deletions(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 82d9a3062..aba165c26 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -167,8 +167,6 @@ def JSON_CONTAINS(builder, expr, path, key): else: json_sql = builder(expr) return json_sql, " ? ", builder(key) - def JSON_CONTAINS_JSON(builder, sub_value, value): - return builder(sub_value), " <@ ", builder(value) def JSON_IS_CONTAINED(builder, value, contained_in): raise NotImplementedError('Not needed') def JSON_HAS_ANY(builder, array, value): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 46112393a..2bb9ea17b 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1602,21 +1602,6 @@ def contains(monad, item, not_in=False): expr = translator.JsonBoolExprMonad(translator, bool, ['NOT', sql]) return expr - # TODO not_in - # def contains_json(monad, item, not_in=False): - # import ipdb; ipdb.set_trace() - # translator = monad.translator - # parent_sql, = monad.getsql() - # item_sql, = item.getsql() - # if isinstance(item, translator.JsonMixin): - # sql = ['JSON_CONTAINS_JSON', item_sql, parent_sql] - # return translator.JsonBoolExprMonad(monad.translator, bool, sql) - # elif isinstance(item, translator.StringMixin): - # sql = ['JSON_CONTAINS', item_sql, 
parent_sql] - # return translator.JsonBoolExprMonad(monad.translator, bool, sql) - # else: - # raise TypeError('Invalid JSON key: %s,' % ast2src(item.node)) - def __or__(monad, other): translator = monad.translator if not isinstance(other, translator.JsonMixin): From bfc4739998ffdf5bc82c46b6700e4210b8bea0b7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 12 Jul 2016 21:03:38 +0300 Subject: [PATCH 033/547] Refactoring of JsonMixin & JsonItemMonad --- pony/orm/dbproviders/sqlite.py | 31 ------------ pony/orm/sqltranslation.py | 87 +++++++++++++++------------------- 2 files changed, 38 insertions(+), 80 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index db2d702eb..cf33e5fb6 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -54,35 +54,6 @@ class SQLiteTranslator(sqltranslation.SQLTranslator): StringMixin_UPPER = make_overriden_string_func('PY_UPPER') StringMixin_LOWER = make_overriden_string_func('PY_LOWER') - class CmpMonad(sqltranslation.CmpMonad): - def __init__(monad, op, left, right): - translator = left.translator - sqltranslation.CmpMonad.__init__(monad, op, left, right) - if not isinstance(left, translator.JsonMixin): - return - if op in ('==', '!='): - if isinstance(right, sqltranslation.AttrMonad) : - left.quote_strings = False - - class JsonItemMonad(sqltranslation.JsonItemMonad): - quote_strings = True - def getsql(monad): - sql, = sqltranslation.JsonItemMonad.getsql(monad) - if monad.quote_strings: - sql[0] = 'JSON_GETPATH__QUOTE_STRINGS' - return [sql] - def nonzero(monad): - translator = monad.translator - if translator.database.provider.json1_available: - monad.quote_strings = False - return monad - sql = ['PY_JSON_NONZERO'] - expr_sql = monad.attr_monad.getsql()[0] - path_sql = monad._get_path_sql(monad.path) - sql.extend([expr_sql, path_sql]) - return translator.BoolExprMonad(translator, sql) - - class SQLiteBuilder(SQLBuilder): dialect = 'SQLite' def 
__init__(builder, provider, ast): @@ -185,8 +156,6 @@ def JSON_CONTAINS(builder, expr, path, key): # TODO impl with builder.json1_disabled(): return 'py_json_contains(', builder(expr), ', ', builder(path), ', ', builder(key), ')' - def PY_JSON_NONZERO(builder, expr, path): - return 'py_json_nonzero(', builder(expr), ', ', builder(path), ')' @contextmanager def json1_disabled(builder): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 2bb9ea17b..1a8445f40 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1568,31 +1568,23 @@ class JsonMixin(object): disable_distinct = True # at least in Oracle we cannot use DISTINCT with JSON column disable_ordering = True # at least in Oracle we cannot use ORDER BY with JSON column - @classmethod - def _get_path_sql(cls, items): - return ['JSON_PATH'] + [cls._get_value(item) for item in items] + def mixin_init(monad): + assert monad.type is Json, monad.type + + def _get_path_sql(monad, path): + result = [ 'JSON_PATH' ] + for item in path: + if isinstance(item, EllipsisMonad): + result.append(AnyStr) + elif isinstance(item, slice): + result.append(AnyNum) + elif isinstance(item, (NumericConstMonad, StringConstMonad)): + result.append(item.value) + raise TypeError('Invalid JSON path item: %s' % ast2src(item.node)) + return result - @classmethod - def _get_value(cls, monad): - tr = monad.translator - if isinstance(monad, EllipsisMonad): - return AnyStr - if isinstance(monad, FullSliceMonad): - return AnyNum - if not isinstance(monad, (tr.NumericConstMonad, tr.StringConstMonad)): - raise TypeError('Invalid JSON path item: %s' % ast2src(monad.node)) - return monad.value - - def __getitem__(monad, item, is_overriden=False): - ''' - Transform the item and return it. Please override. 
- ''' - assert is_overriden, 'Json.__getitem__ is not a valid implementation' - if isinstance(item, slice) \ - and isinstance(item.start, (NoneType, NoneMonad)) \ - and isinstance(item.stop, (NoneType, NoneMonad)): - return FullSliceMonad(monad.translator) - return item + def __getitem__(monad, key): + return monad.translator.JsonItemMonad(monad, key) def contains(monad, item, not_in=False): translator = monad.translator @@ -1617,15 +1609,7 @@ def len(monad): return translator.NumericExprMonad( translator, int, ['JSON_ARRAY_LENGTH', sql]) -class JsonAttrMonad(JsonMixin, AttrMonad): - def __getitem__(monad, key): - key = JsonMixin.__getitem__(monad, key, True) - return monad.translator.JsonItemMonad(monad, [key]) - - @property - def attr_monad(monad): - return monad - +class JsonAttrMonad(JsonMixin, AttrMonad): pass class ParamMonad(Monad): @staticmethod @@ -1791,21 +1775,31 @@ class JsonExprMonad(JsonMixin, ExprMonad): pass class JsonItemMonad(JsonMixin, Monad): - - def __init__(monad, attr_monad, path): - translator = attr_monad.translator - monad.attr_monad = attr_monad - monad.path = path + def __init__(monad, parent, key): + assert isinstance(parent, JsonMixin), parent + translator = parent.translator + if isinstance(key, slice): + for item in (key.start, key.stop, key.step): + if not isinstance(item, (NoneType, NoneMonad)): + throw(NotImplementedError) + elif not isinstance(key, (EllipsisMonad, StringConstMonad, NumericConstMonad)): + throw(NotImplementedError) Monad.__init__(monad, translator, Json) + monad.parent = parent + monad.key = key - def __getitem__(monad, key): - key = JsonMixin.__getitem__(monad, key, True) - return monad.translator.JsonItemMonad( - monad.attr_monad, monad.path + [key]) + def get_path(monad): + path = [] + while isinstance(monad, JsonItemMonad): + path.append(monad.key) + monad = monad.parent + path.reverse() + return monad, path def getsql(monad): - base_sql, = monad.attr_monad.getsql() - path_sql = monad._get_path_sql(monad.path) 
+ base_monad, path = monad.get_path() + base_sql = base_monad.getsql()[0] + path_sql = monad._get_path_sql(path) if any(isinstance(item, AnyItem) for item in path_sql): sql = ['JSON_GETPATH_STARRED'] else: @@ -1852,11 +1846,6 @@ def __init__(monad, translator, value=None): class EllipsisMonad(ConstMonad): pass -class FullSliceMonad(ConstMonad): - SLICE = slice(None, None, None) - def __init__(monad, translator): - ConstMonad.__init__(monad, translator, monad.SLICE) - class BufferConstMonad(BufferMixin, ConstMonad): pass class StringConstMonad(StringMixin, ConstMonad): From 0c13d4a236210c008379ed192c84546334195ebd Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 12 Jul 2016 21:09:59 +0300 Subject: [PATCH 034/547] Remove JsonBoolExprMonad --- pony/orm/sqltranslation.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 1a8445f40..74bcbb602 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1590,8 +1590,7 @@ def contains(monad, item, not_in=False): translator = monad.translator expr = monad.translator.JsonContainsExprMonad(monad, item) if not_in: - sql, = expr.getsql() - expr = translator.JsonBoolExprMonad(translator, bool, ['NOT', sql]) + expr = translator.BoolExprMonad(translator, ['NOT', expr.getsql()[0]]) return expr def __or__(monad, other): @@ -1697,9 +1696,6 @@ class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass -class JsonBoolExprMonad(ExprMonad): - pass - class JsonContainsExprMonad(Monad): def __init__(monad, json_monad, item): monad.json_monad = json_monad From ff7b83a8f370ed41480699fab82fc2624e65a112 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 01:48:21 +0300 Subject: [PATCH 035/547] Refactor of JsonMixin.contains, get rid of JsonContainsExprMonad --- pony/orm/dbproviders/mysql.py | 42 
++++++-------------------------- pony/orm/dbproviders/oracle.py | 39 ++++------------------------- pony/orm/dbproviders/postgres.py | 12 --------- pony/orm/sqlbuilding.py | 2 ++ pony/orm/sqltranslation.py | 39 +++++++++++------------------ 5 files changed, 29 insertions(+), 105 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index fdd501ff8..075ddf17b 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -1,6 +1,7 @@ from __future__ import absolute_import from pony.py23compat import PY2, imap, basestring, buffer, int_types +import json from decimal import Decimal from datetime import datetime, date, time, timedelta from uuid import UUID @@ -86,37 +87,6 @@ def dispatch_type(cls, typ): raise sqltranslation.AbortCast return sqltranslation.CastFromJsonExprMonad.dispatch_type(typ) - class JsonContainsExprMonad(sqltranslation.JsonContainsExprMonad): - - def __init__(monad, json_monad, item): - if not isinstance(item, sqltranslation.StringConstMonad): - raise NotImplementedError - sqltranslation.JsonContainsExprMonad.__init__( - monad, json_monad, item - ) - - def _dict_contains(monad): - path_sql = monad.json_monad._get_path_sql( - getattr(monad.json_monad, 'path', ()) - ) - path_sql.append(monad.item.value) - return ['JSON_CONTAINS_PATH', monad.attr_sql, path_sql] - - def _list_contains(monad): - translator = monad.translator - path_sql = monad.json_monad._get_path_sql( - getattr(monad.json_monad, 'path', ()) - ) - item = translator.ConstMonad.new(translator, '["%s"]' % monad.item.value) - item_sql, = item.getsql() - return ['JSON_CONTAINS', monad.attr_sql, path_sql, item_sql] - - def getsql(monad): - return [ - ['OR', monad._dict_contains(), monad._list_contains()] - ] - - class MySQLBuilder(SQLBuilder): dialect = 'MySQL' def CONCAT(builder, *args): @@ -169,9 +139,13 @@ def EQ_JSON(builder, left, right): def NE_JSON(builder, left, right): return '(', builder(left), '!=', builder.AS_JSON(right), ')' def 
JSON_CONTAINS(builder, expr, path, key): - return 'json_contains(', builder(expr), ', ', builder(key), ', ', builder(path), ')' - def JSON_CONTAINS_PATH(builder, expr, path): - return 'json_contains_path(', builder(expr), ", 'one', ", builder(path), ')' + assert key[0] == 'VALUE' and isinstance(key[1], basestring) + expr_sql = builder(expr) + result = [ '(json_contains(', expr_sql, ', ', builder([ 'VALUE', json.dumps([ key[1] ]) ]) ] + path_sql = builder(path) + path_with_key_sql = builder(path + [ key[1] ]) + result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] + return result class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 443f4626f..5aaa7b9e3 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -129,32 +129,6 @@ class JsonItemMonad(sqltranslation.JsonItemMonad): def nonzero(monad): raise NotImplementedError - class JsonContainsExprMonad(sqltranslation.JsonContainsExprMonad): - - def __init__(monad, json_monad, item): - if not isinstance(item, sqltranslation.StringConstMonad): - raise NotImplementedError - sqltranslation.JsonContainsExprMonad.__init__( - monad, json_monad, item - ) - - def _dict_contains(monad): - path_sql = monad.json_monad._get_path_sql( - getattr(monad.json_monad, 'path', ()) - ) - path_sql.append(monad.item.value) - return ['JSON_CONTAINS_PATH', monad.attr_sql, path_sql] - - def _list_contains(monad): - path_sql = monad.json_monad._get_path_sql( - getattr(monad.json_monad, 'path', ()) - ) - return ['JSON_LIST_CONTAINS', monad.attr_sql, path_sql, monad.item.value] - - def getsql(monad): - return [ ['OR', monad._dict_contains(), monad._list_contains()] ] - - class OraBuilder(sqlbuilding.SQLBuilder): dialect = 'Oracle' def INSERT(builder, table_name, columns, values, returning=None): @@ -263,14 +237,11 @@ def JSON_GETPATH(builder, expr, key): return 
'REGEXP_REPLACE(', query, ", '(^\[|\]$)', '')" def JSON_EXISTS(builder, expr, key): return 'JSON_EXISTS(', builder(expr), ', ', builder(key), ')' - def JSON_CONTAINS_PATH(builder, expr, path): - return builder.JSON_EXISTS(expr, path) - def JSON_LIST_CONTAINS(builder, expr, path, key): - query = 'JSON_QUERY(', builder(expr), ', ', builder(path), ')' - return 'REGEXP_LIKE(', query, ', \'', search_in_json_list_regexp(key), '\')' - -def search_in_json_list_regexp(what): - return r'^\[(.+, ?)?"%s"(, ?.+)?\]$' % what + def JSON_CONTAINS(builder, expr, path, key): + assert key[0] == 'VALUE' and isinstance(key[1], basestring) + expr_sql = builder(expr) + path_sql = builder(path + [ key[1] ]) + return 'JSON_EXISTS(', expr_sql, ', ', path_sql, ')' class OraBoolConverter(dbapiprovider.BoolConverter): if not PY2: diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index aba165c26..16ac29e39 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -1,7 +1,6 @@ from __future__ import absolute_import from pony.py23compat import PY2, basestring, unicode, buffer, int_types -import json from decimal import Decimal from datetime import datetime, date, time, timedelta from uuid import UUID @@ -99,17 +98,6 @@ def getsql(monad): monad.sql = ['SINGLE_QUOTES', monad.sql] return sqltranslation.CastToJsonExprMonad.getsql(monad) - class JsonContainsExprMonad(sqltranslation.JsonContainsExprMonad): - def getsql(monad): - json_monad = monad.json_monad - path = getattr(json_monad, 'path', ()) - path_sql = json_monad._get_path_sql(path) if path else None - item_sql, = monad.item.getsql() - return [ - ['JSON_CONTAINS', monad.attr_sql, path_sql, item_sql] - ] - - class PGValue(Value): __slots__ = [] def __unicode__(self): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 292459805..356d63f99 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -523,3 +523,5 @@ def JSON_GETPATH_STARRED(builder, expr, 
key): return builder.JSON_GETPATH(expr, key) def CAST(builder, expr, type): return 'CAST(', builder(expr), ' AS ', type, ')' + def JSON_CONTAINS(builder, expr, path, key): + raise NotImplementedError diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 74bcbb602..0ac401f49 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1571,7 +1571,10 @@ class JsonMixin(object): def mixin_init(monad): assert monad.type is Json, monad.type - def _get_path_sql(monad, path): + def get_path(monad): + return monad, [] + + def get_path_sql(monad, path): result = [ 'JSON_PATH' ] for item in path: if isinstance(item, EllipsisMonad): @@ -1586,12 +1589,16 @@ def _get_path_sql(monad, path): def __getitem__(monad, key): return monad.translator.JsonItemMonad(monad, key) - def contains(monad, item, not_in=False): + def contains(monad, key, not_in=False): + if not isinstance(key, StringConstMonad): raise NotImplementedError translator = monad.translator - expr = monad.translator.JsonContainsExprMonad(monad, item) - if not_in: - expr = translator.BoolExprMonad(translator, ['NOT', expr.getsql()[0]]) - return expr + base_monad, path = monad.get_path() + base_sql = base_monad.getsql()[0] + path_sql = monad.get_path_sql(path) + key_sql = key.getsql()[0] + sql = [ 'JSON_CONTAINS', base_sql, path_sql, key_sql ] + if not_in: sql = [ 'NOT', sql ] + return translator.BoolExprMonad(translator, sql) def __or__(monad, other): translator = monad.translator @@ -1696,24 +1703,6 @@ class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass -class JsonContainsExprMonad(Monad): - def __init__(monad, json_monad, item): - monad.json_monad = json_monad - monad.item = item - Monad.__init__(monad, json_monad.translator, bool) - monad.attr_sql = json_monad.attr_monad.getsql()[0] - - def getsql(monad): - json_monad = monad.json_monad - path_sql = 
json_monad._get_path_sql( - getattr(json_monad, 'path', ()) - ) - item_sql, = monad.item.getsql() - return [ - ['JSON_CONTAINS', monad.attr_sql, path_sql, item_sql] - ] - - class CastFromJsonExprMonad(ExprMonad): def __init__(monad, type_to, translator, sql): @@ -1795,7 +1784,7 @@ def get_path(monad): def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] - path_sql = monad._get_path_sql(path) + path_sql = monad.get_path_sql(path) if any(isinstance(item, AnyItem) for item in path_sql): sql = ['JSON_GETPATH_STARRED'] else: From 7923a0c962fec01bbfa07913f7c571fb2049cd2d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 09:28:51 +0300 Subject: [PATCH 036/547] Minor refactoring of PostgreSQL JSON_CONTAINS method --- pony/orm/dbproviders/postgres.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 16ac29e39..ced76b580 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -150,11 +150,7 @@ def JSON_GETPATH(builder, expr, key): def JSON_CONCAT(builder, left, right): return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): - if path: - json_sql = builder.JSON_GETPATH(expr, path) - else: - json_sql = builder(expr) - return json_sql, " ? ", builder(key) + return (builder.JSON_GETPATH(expr, path) if path else builder(expr)), ' ? 
', builder(key) def JSON_IS_CONTAINED(builder, value, contained_in): raise NotImplementedError('Not needed') def JSON_HAS_ANY(builder, array, value): From cc3a9813b2539f42987043a8254328bcb4479662 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 6 Jul 2016 21:23:48 +0300 Subject: [PATCH 037/547] Refactoring of CmpMonad, CastFromJsonExprMonad, CastToJsonExprMonad --- pony/orm/dbproviders/mysql.py | 41 +----- pony/orm/dbproviders/postgres.py | 54 +------- pony/orm/dbproviders/sqlite.py | 9 +- pony/orm/sqlbuilding.py | 12 +- pony/orm/sqltranslation.py | 128 ++---------------- pony/orm/tests/test_declarative_exceptions.py | 4 +- 6 files changed, 40 insertions(+), 208 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 075ddf17b..b0d566001 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -47,46 +47,6 @@ class MySQLSchema(dbschema.DBSchema): class MySQLTranslator(SQLTranslator): dialect = 'MySQL' - - - class CmpMonad(sqltranslation.CmpMonad): - - def make_json_cast_if_needed(monad, left_sql, right_sql): - translator = monad.left.translator - if monad.op not in ('==', '!='): - return sqltranslation.CmpMonad.make_json_cast_if_needed( - monad, left_sql, right_sql - ) - def need_cast(monad): - if isinstance(monad, sqltranslation.ParamMonad): - return True - return not isinstance(monad, sqltranslation.JsonMixin) - - if need_cast(monad.left): - sql = left_sql[0] - expr = translator.CastToJsonExprMonad( - translator, sql, target_monad=monad.left - ) - return expr.getsql(), right_sql - if need_cast(monad.right): - sql = right_sql[0] - expr = translator.CastToJsonExprMonad( - translator, sql, target_monad=monad.right - ) - return left_sql, expr.getsql() - return left_sql, right_sql - - - class CastFromJsonExprMonad(sqltranslation.CastFromJsonExprMonad): - - @classmethod - def dispatch_type(cls, typ): - if issubclass(typ, int): - return 'signed' - if issubclass(typ, float): - raise 
sqltranslation.AbortCast - return sqltranslation.CastFromJsonExprMonad.dispatch_type(typ) - class MySQLBuilder(SQLBuilder): dialect = 'MySQL' def CONCAT(builder, *args): @@ -146,6 +106,7 @@ def JSON_CONTAINS(builder, expr, path, key): path_with_key_sql = builder(path + [ key[1] ]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result + type_mapping = {str: 'text', bool: 'boolean', int: 'signed', float: None, ormtypes.Json: 'json'} class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index ced76b580..dfccd2eaa 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -48,56 +48,12 @@ def nonzero(monad): str_not_empty = translator.CmpMonad( '!=', monad, empty_str ) - is_true = translator.CastFromJsonExprMonad( - bool, translator, monad.getsql()[0] - ) + is_true = monad.cast_from_json(bool).getsql()[0] sql = ['AND'] sql.extend(str_not_empty.getsql()) sql.extend(is_true.getsql()) return translator.BoolExprMonad(translator, sql) - class CmpMonad(sqltranslation.CmpMonad): - - def make_json_cast_if_needed(monad, left_sql, right_sql): - translator = monad.left.translator - if monad.op not in ('==', '!='): - return sqltranslation.CmpMonad.make_json_cast_if_needed( - monad, left_sql, right_sql - ) - if isinstance(monad.left, sqltranslation.NumericMixin): - sql = left_sql[0] - expr = translator.CastToJsonExprMonad( - translator, sql, target_monad=monad.left - ) - return expr.getsql(), right_sql - if isinstance(monad.right, sqltranslation.NumericMixin): - sql = right_sql[0] - expr = translator.CastToJsonExprMonad( - translator, sql, target_monad=monad.right - ) - return left_sql, expr.getsql() - return left_sql, right_sql - - - class CastFromJsonExprMonad(sqltranslation.CastFromJsonExprMonad): - - @classmethod - def dispatch_type(cls, typ): - sql_type = 
sqltranslation.CastFromJsonExprMonad.dispatch_type(typ) - if not issubclass(typ, (int, float, bool)): - return sql_type - return 'text::%s' % sql_type - - - class CastToJsonExprMonad(sqltranslation.CastToJsonExprMonad): - - cast_to = 'jsonb' - - def getsql(monad): - if isinstance(monad.target_monad, sqltranslation.NumericConstMonad): - monad.sql = ['SINGLE_QUOTES', monad.sql] - return sqltranslation.CastToJsonExprMonad.getsql(monad) - class PGValue(Value): __slots__ = [] def __unicode__(self): @@ -159,10 +115,12 @@ def JSON_HAS_ALL(builder, array, value): raise NotImplementedError def JSON_ARRAY_LENGTH(builder, value): return 'jsonb_array_length(', builder(value), ')' - def _as_json(builder, target): - return '(', builder(target), ')::jsonb' def CAST(builder, expr, type): - return '(', builder(expr), ')::', type + return '(', builder(expr), ')::', builder.get_cast_type_name(type) + def JSON_CAST(builder, expr, type): + type = builder.get_cast_type_name(type) + if type == 'text': return '(', builder(expr), ')::', type + return '(', builder(expr), ')::text::', type def SINGLE_QUOTES(builder, expr): return "'", builder(expr), "'" diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index cf33e5fb6..2a471fe7f 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -415,11 +415,12 @@ def func(value): @print_traceback def unwrap_extract_json(value): - ''' - [null,some-value] -> some-value - ''' + # [null,some-value] -> some-value assert value.startswith('[null,') - return value[6:-1] + result = value[6:-1] + if not result.startswith(('[', '{')): + result = json.loads(result) + return result @print_traceback def py_json_extract(value, path, quote_strings): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 356d63f99..f7d36ccd7 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -521,7 +521,15 @@ def JSON_GETPATH(builder, expr, key): raise NotImplementedError def 
JSON_GETPATH_STARRED(builder, expr, key): return builder.JSON_GETPATH(expr, key) - def CAST(builder, expr, type): - return 'CAST(', builder(expr), ' AS ', type, ')' def JSON_CONTAINS(builder, expr, path, key): raise NotImplementedError + def CAST(builder, expr, type): + type_name = builder.get_cast_type_name(type) + if type_name is None: return builder(expr) + return 'CAST(', builder(expr), ' AS ', type_name, ')' + JSON_CAST = CAST + def get_cast_type_name(builder, type): + if isinstance(type, basestring): return type + if type not in builder.typecast_mapping: throw(NotImplementedError, type) + return builder.typecast_mapping[type] + typecast_mapping = {unicode: 'text', bool: 'boolean', int: 'integer', float: 'real', Json: None} diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 0ac401f49..6cd6027dd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1070,6 +1070,7 @@ def __or__(monad): throw(TypeError) def __and__(monad): throw(TypeError) def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) + def cast_from_json(monad, type): throw(TypeError) class RawSQLMonad(Monad): def __init__(monad, translator, rawtype, varkey): @@ -1571,6 +1572,11 @@ class JsonMixin(object): def mixin_init(monad): assert monad.type is Json, monad.type + def cast_from_json(monad, type): + translator = monad.translator + if issubclass(type, Json): return monad + return translator.ExprMonad.new(translator, type, ['JSON_CAST', monad.getsql()[0], type ]) + def get_path(monad): return monad, [] @@ -1702,62 +1708,8 @@ class DateExprMonad(DateMixin, ExprMonad): pass class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass +class JsonExprMonad(JsonMixin, ExprMonad): pass -class CastFromJsonExprMonad(ExprMonad): - - def __init__(monad, type_to, translator, sql): - monad.type_to = type_to - ExprMonad.__init__(monad, type_to, 
translator, sql) - - - @classmethod - def dispatch_type(cls, typ): - if issubclass(typ, bool): - return 'boolean' - if issubclass(typ, int): - return 'integer' - if issubclass(typ, float): - return 'real' - - - def getsql(monad): - sql_type = monad.dispatch_type(monad.type_to) - sql = ['CAST', monad.sql, sql_type] - return [sql] - - -class CastToJsonExprMonad(ExprMonad): - - cast_to = 'JSON' - target_monad = None - - def __new__(cls, *args, **kwargs): - kwargs.pop('target_monad', None) - return ExprMonad.__new__(cls, *args, **kwargs) - - def __init__(monad, translator, sql, target_monad=None): - ExprMonad.__init__(monad, translator, Json, sql) - if target_monad: - monad.target_monad = target_monad - - def getsql(monad): - sql = monad.sql - if monad.target_monad: - m = monad.target_monad - if isinstance(m, ConstMonad) and issubclass(m.type, bool): - sql = [ - "RAWSQL", - "'%s'" % ('true' if m.value else 'false') - ] - sql = ['CAST', sql, monad.cast_to] - return [sql] - - -class AbortCast(Exception): - pass - -class JsonExprMonad(JsonMixin, ExprMonad): - pass class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): @@ -1807,6 +1759,7 @@ def new(translator, value): elif value_type is datetime: cls = translator.DatetimeConstMonad elif value_type is NoneType: cls = translator.NoneMonad elif value_type is buffer: cls = translator.BufferConstMonad + elif value_type is Json: cls = translator.JsonConstMonad elif issubclass(value_type, type(Ellipsis)): cls = translator.EllipsisMonad else: throw(NotImplementedError, value_type) # pragma: no cover result = cls(translator, value) @@ -1831,12 +1784,12 @@ def __init__(monad, translator, value=None): class EllipsisMonad(ConstMonad): pass -class BufferConstMonad(BufferMixin, ConstMonad): pass - class StringConstMonad(StringMixin, ConstMonad): def len(monad): return monad.translator.ConstMonad.new(monad.translator, len(monad.value)) +class JsonConstMonad(JsonMixin, ConstMonad): pass +class 
BufferConstMonad(BufferMixin, ConstMonad): pass class NumericConstMonad(NumericMixin, ConstMonad): pass class DateConstMonad(DateMixin, ConstMonad): pass class TimeConstMonad(TimeMixin, ConstMonad): pass @@ -1894,62 +1847,17 @@ def __init__(monad, op, left, right): result_type, left, right = coerce_monads(left, right) BoolMonad.__init__(monad, translator) monad.op = op - monad.left = left - monad.right = right monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False) if isinstance(left, JsonMixin): - json_monad, other_monad = left, right - elif isinstance(right, JsonMixin): - json_monad, other_monad = right, left - else: - return - - # Customizing comparisons for Json - if op in ('==', '!='): - if isinstance(other_monad, StringConstMonad): - other_monad.value = '"%s"' % right.value - elif isinstance(other_monad, ParamMonad): - other_monad.converter = translator.database.provider \ - .get_converter_by_py_type(Json) + left = left.cast_from_json(right.type) + if isinstance(right, JsonMixin): + right = right.cast_from_json(left.type) + monad.left = left + monad.right = right def negate(monad): return monad.translator.CmpMonad(cmp_negate[monad.op], monad.left, monad.right) - - def make_json_cast_if_needed(monad, left_sql, right_sql): - translator = monad.left.translator - is_needed = monad.op in ('<', '>', '==', '!=') and any( - isinstance(m, NumericMixin) for m in (monad.left, monad.right) - ) - if not is_needed: - return left_sql, right_sql - # special handling for boolean constants - if monad.op in ('==', '!='): - if isinstance(monad.left, ConstMonad) and issubclass(monad.left.type, bool): - # FIXME use CastToJson - bool_sql = [ - "RAWSQL", - "'%s'" % ('true' if monad.left.value else 'false') - ] - return [bool_sql], right_sql - if isinstance(monad.right, ConstMonad) and issubclass(monad.right.type, bool): - bool_sql = [ - "RAWSQL", - "'%s'" % ('true' if monad.right.value else 'false') - ] - return left_sql, [bool_sql] - if 
isinstance(monad.left, JsonMixin): - other_monad = monad.right - expr = translator.CastFromJsonExprMonad( - other_monad.type, translator, left_sql[0] - ) - return expr.getsql(), right_sql - other_monad = monad.left - expr = translator.CastFromJsonExprMonad( - other_monad.type, translator, right_sql[0] - ) - return left_sql, expr.getsql() - def getsql(monad, subquery=None): op = monad.op left_sql = monad.left.getsql() @@ -1958,12 +1866,6 @@ def getsql(monad, subquery=None): if op == 'is not': return [ sqland([ [ 'IS_NOT_NULL', item ] for item in left_sql ]) ] right_sql = monad.right.getsql() - - if any(isinstance(m, JsonMixin) for m in (monad.left, monad.right)): - try: - left_sql, right_sql = monad.make_json_cast_if_needed(left_sql, right_sql) - except AbortCast: - pass if len(left_sql) == 1 and left_sql[0][0] == 'ROW': left_sql = left_sql[0][1:] if len(right_sql) == 1 and right_sql[0][0] == 'ROW': diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 8721272bb..42af7a8ef 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -210,7 +210,9 @@ def test49(self): sum(s.name for s in Student) @raises_exception(NotImplementedError, "Parameter {'a':'b'} has unsupported type ") def test50(self): - select(s for s in Student if s.name == {'a' : 'b'}) + # cannot compare JSON value to dynamic string, + # because a database does not provide json.dumps(s.name) functionality + select(s for s in Student if s.name == {'a': 'b'}) @raises_exception(IncomparableTypesError, "Incomparable types '%s' and 'int' in expression: s.name > a & 2" % unicode.__name__) def test51(self): a = 1 From d10130ad1eea55485c7d2cf4e19d48f411372d0e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 09:14:37 +0300 Subject: [PATCH 038/547] Remove JsonConcatExprMonad --- pony/orm/sqltranslation.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git 
a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 6cd6027dd..3ccbc502e 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1613,7 +1613,7 @@ def __or__(monad, other): left_sql, = monad.getsql() right_sql, = other.getsql() sql = ['JSON_CONCAT', left_sql, right_sql] - return translator.JsonConcatExprMonad(translator, Json, sql) + return translator.JsonExprMonad(translator, Json, sql) def len(monad): translator = monad.translator @@ -1909,11 +1909,6 @@ def getsql(monad, subquery=None): result.extend(operand_sql) return [ result ] - -class JsonConcatExprMonad(JsonMixin, ExprMonad): - pass - - class AndMonad(LogicalBinOpMonad): binop = 'AND' From 4628e07f9e68e0597545b64388038667c2e49f39 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 09:36:51 +0300 Subject: [PATCH 039/547] Minor refactoring of JsonItemMonad.getsql() --- pony/orm/sqltranslation.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 3ccbc502e..c3f691c4e 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1737,12 +1737,10 @@ def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] path_sql = monad.get_path_sql(path) - if any(isinstance(item, AnyItem) for item in path_sql): - sql = ['JSON_GETPATH_STARRED'] - else: - sql = ['JSON_GETPATH'] - sql.extend((base_sql, path_sql)) - return [sql] + for item in path_sql: + if isinstance(item, AnyItem): + return [ ['JSON_GETPATH_STARRED', base_sql, path_sql] ] + return [ ['JSON_GETPATH', base_sql, path_sql] ] def nonzero(monad): return monad From 1a811fe9a5b3730eb2ce83514ccebf2cd8f966f7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 11:32:22 +0300 Subject: [PATCH 040/547] Cleanup of builder methods --- pony/orm/dbproviders/mysql.py | 6 ++---- pony/orm/dbproviders/oracle.py | 6 ++---- pony/orm/dbproviders/postgres.py | 8 -------- 
pony/orm/sqlbuilding.py | 8 ++++++-- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index b0d566001..ba0134c68 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -92,12 +92,10 @@ def JSON_GETPATH(builder, expr, key): return 'json_extract(', builder(expr), ', ', builder(key), ')' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' - def AS_JSON(builder, target): - return 'CAST(', builder(target), ' AS JSON)' def EQ_JSON(builder, left, right): - return '(', builder(left), '=', builder.AS_JSON(right), ')' + return '(', builder(left), ' = CAST(', builder(right), ' AS JSON))' def NE_JSON(builder, left, right): - return '(', builder(left), '!=', builder.AS_JSON(right), ')' + return '(', builder(left), ' != CAST(', builder(right), ' AS JSON))' def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 5aaa7b9e3..35a2abf2a 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -230,13 +230,11 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' - def JSON_GETPATH_STARRED(builder, expr, key): - return 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' def JSON_GETPATH(builder, expr, key): query = 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' return 'REGEXP_REPLACE(', query, ", '(^\[|\]$)', '')" - def JSON_EXISTS(builder, expr, key): - return 'JSON_EXISTS(', builder(expr), ', ', builder(key), ')' + def JSON_GETPATH_STARRED(builder, expr, key): + return 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' def JSON_CONTAINS(builder, expr, path, key): 
assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index dfccd2eaa..12ede0433 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -107,12 +107,6 @@ def JSON_CONCAT(builder, left, right): return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): return (builder.JSON_GETPATH(expr, path) if path else builder(expr)), ' ? ', builder(key) - def JSON_IS_CONTAINED(builder, value, contained_in): - raise NotImplementedError('Not needed') - def JSON_HAS_ANY(builder, array, value): - raise NotImplementedError - def JSON_HAS_ALL(builder, array, value): - raise NotImplementedError def JSON_ARRAY_LENGTH(builder, value): return 'jsonb_array_length(', builder(value), ')' def CAST(builder, expr, type): @@ -121,8 +115,6 @@ def JSON_CAST(builder, expr, type): type = builder.get_cast_type_name(type) if type == 'text': return '(', builder(expr), ')::', type return '(', builder(expr), ')::text::', type - def SINGLE_QUOTES(builder, expr): - return "'", builder(expr), "'" class PGStrConverter(dbapiprovider.StrConverter): if PY2: diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index f7d36ccd7..7c837ae61 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -518,11 +518,15 @@ def JSON_PATH(builder, *items): result.append('\'') return result def JSON_GETPATH(builder, expr, key): - raise NotImplementedError + throw(NotImplementedError) def JSON_GETPATH_STARRED(builder, expr, key): return builder.JSON_GETPATH(expr, key) + def JSON_CONCAT(builder, left, right): + throw(NotImplementedError) def JSON_CONTAINS(builder, expr, path, key): - raise NotImplementedError + throw(NotImplementedError) + def JSON_ARRAY_LENGTH(builder, value): + throw(NotImplementedError) def CAST(builder, expr, type): type_name = builder.get_cast_type_name(type) if type_name is None: return builder(expr) 
From ddcb7d4aa7d6d99c69bbca9ab8ed418317dd747c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 15:20:05 +0300 Subject: [PATCH 041/547] Get rid of JsonMixin.get_path_sql() --- pony/orm/dbproviders/mysql.py | 8 +++---- pony/orm/dbproviders/oracle.py | 10 ++++----- pony/orm/dbproviders/postgres.py | 8 +++---- pony/orm/dbproviders/sqlite.py | 14 ++++++------ pony/orm/sqlbuilding.py | 10 ++++----- pony/orm/sqltranslation.py | 37 +++++++++++--------------------- 6 files changed, 38 insertions(+), 49 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index ba0134c68..57b54b078 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -88,8 +88,8 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return 'DATE_SUB(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' - def JSON_GETPATH(builder, expr, key): - return 'json_extract(', builder(expr), ', ', builder(key), ')' + def JSON_GETPATH(builder, expr, path): + return 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' def EQ_JSON(builder, left, right): @@ -100,8 +100,8 @@ def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) result = [ '(json_contains(', expr_sql, ', ', builder([ 'VALUE', json.dumps([ key[1] ]) ]) ] - path_sql = builder(path) - path_with_key_sql = builder(path + [ key[1] ]) + path_sql = builder.json_path(path) + path_with_key_sql = builder.json_path(path + [ key[1] ]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result type_mapping = {str: 'text', bool: 'boolean', int: 'signed', float: None, ormtypes.Json: 'json'} diff --git a/pony/orm/dbproviders/oracle.py 
b/pony/orm/dbproviders/oracle.py index 35a2abf2a..5f7206596 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -230,15 +230,15 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' - def JSON_GETPATH(builder, expr, key): - query = 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' + def JSON_GETPATH(builder, expr, path): + query = 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' return 'REGEXP_REPLACE(', query, ", '(^\[|\]$)', '')" - def JSON_GETPATH_STARRED(builder, expr, key): - return 'JSON_QUERY(', builder(expr), ', ', builder(key), ' WITH WRAPPER)' + def JSON_GETPATH_STARRED(builder, expr, path): + return 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) - path_sql = builder(path + [ key[1] ]) + path_sql = builder.json_path(path + [ key[1] ]) return 'JSON_EXISTS(', expr_sql, ', ', path_sql, ')' class OraBoolConverter(dbapiprovider.BoolConverter): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 12ede0433..7c76d505c 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -92,17 +92,17 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' - def JSON_PATH(builder, *items): + def json_path(builder, path): result = [] - for item in items: + for item in path: if isinstance(item, int): result.append(str(item)) elif isinstance(item, basestring): result.append(item if is_ident(item) else '"%s"' % item.replace('"', '\\"')) else: assert False, item 
return '{%s}' % ','.join(result) - def JSON_GETPATH(builder, expr, key): - return '(', builder(expr), "#>", builder(key), ')' + def JSON_GETPATH(builder, expr, path): + return '(', builder(expr), "#>", builder.json_path(path), ')' def JSON_CONCAT(builder, left, right): return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 2a471fe7f..6a796052e 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -138,14 +138,14 @@ def JSON_PATH(builder, *items): if builder.json1_available: return SQLBuilder.JSON_PATH(builder, *items) return builder.VALUE(json.dumps(items)) - def JSON_GETPATH(builder, expr, key): + def JSON_GETPATH(builder, expr, path): if not builder.json1_available: - return 'py_json_extract(', builder(expr), ', ', builder(key), ', 0)' - return 'json_extract(', builder(expr), ', ', builder(key), ')' - def JSON_GETPATH__QUOTE_STRINGS(builder, expr, key): + return 'py_json_extract(', builder(expr), ', ', builder.json_path(path), ', 0)' + return 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' + def JSON_GETPATH__QUOTE_STRINGS(builder, expr, path): if not builder.json1_available: - return 'py_json_extract(', builder(expr), ', ', builder(key), ', 1)' - ret = 'json_extract(', builder(expr), ', null, ', builder(key), ')' + return 'py_json_extract(', builder(expr), ', ', builder.json_path(path), ', 1)' + ret = 'json_extract(', builder(expr), ', null, ', builder.json_path(path), ')' return 'unwrap_extract_json(', ret, ')' def JSON_ARRAY_LENGTH(builder, value): if not builder.json1_available: @@ -155,7 +155,7 @@ def JSON_CONTAINS(builder, expr, path, key): # if builder.json1_available: # TODO impl with builder.json1_disabled(): - return 'py_json_contains(', builder(expr), ', ', builder(path), ', ', builder(key), ')' + return 'py_json_contains(', builder(expr), ', ', builder.json_path(path), ', ', 
builder(key), ')' @contextmanager def json1_disabled(builder): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 7c837ae61..31d298201 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -503,9 +503,9 @@ def RANDOM(builder): def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] - def JSON_PATH(builder, *items): + def json_path(builder, path): result = ['\'$'] - for item in items: + for item in path: if isinstance(item, int): result.append('[%d]' % item) elif item is AnyNum: @@ -517,10 +517,10 @@ def JSON_PATH(builder, *items): else: assert False result.append('\'') return result - def JSON_GETPATH(builder, expr, key): + def JSON_GETPATH(builder, expr, path): throw(NotImplementedError) - def JSON_GETPATH_STARRED(builder, expr, key): - return builder.JSON_GETPATH(expr, key) + def JSON_GETPATH_STARRED(builder, expr, path): + return builder.JSON_GETPATH(expr, path) def JSON_CONCAT(builder, left, right): throw(NotImplementedError) def JSON_CONTAINS(builder, expr, path, key): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index c3f691c4e..e469cd8d5 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1580,18 +1580,6 @@ def cast_from_json(monad, type): def get_path(monad): return monad, [] - def get_path_sql(monad, path): - result = [ 'JSON_PATH' ] - for item in path: - if isinstance(item, EllipsisMonad): - result.append(AnyStr) - elif isinstance(item, slice): - result.append(AnyNum) - elif isinstance(item, (NumericConstMonad, StringConstMonad)): - result.append(item.value) - raise TypeError('Invalid JSON path item: %s' % ast2src(item.node)) - return result - def __getitem__(monad, key): return monad.translator.JsonItemMonad(monad, key) @@ -1600,9 +1588,8 @@ def contains(monad, key, not_in=False): translator = monad.translator base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] - path_sql 
= monad.get_path_sql(path) key_sql = key.getsql()[0] - sql = [ 'JSON_CONTAINS', base_sql, path_sql, key_sql ] + sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] if not_in: sql = [ 'NOT', sql ] return translator.BoolExprMonad(translator, sql) @@ -1714,16 +1701,19 @@ class JsonExprMonad(JsonMixin, ExprMonad): pass class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): assert isinstance(parent, JsonMixin), parent - translator = parent.translator + Monad.__init__(monad, parent.translator, Json) + monad.parent = parent if isinstance(key, slice): for item in (key.start, key.stop, key.step): if not isinstance(item, (NoneType, NoneMonad)): throw(NotImplementedError) - elif not isinstance(key, (EllipsisMonad, StringConstMonad, NumericConstMonad)): - throw(NotImplementedError) - Monad.__init__(monad, translator, Json) - monad.parent = parent - monad.key = key + keyval = AnyNum + elif isinstance(key, EllipsisMonad): + keyval = AnyStr + elif isinstance(key, (StringConstMonad, NumericConstMonad)): + keyval = key.value + else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) + monad.key = keyval def get_path(monad): path = [] @@ -1736,11 +1726,10 @@ def get_path(monad): def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] - path_sql = monad.get_path_sql(path) - for item in path_sql: + for item in path: if isinstance(item, AnyItem): - return [ ['JSON_GETPATH_STARRED', base_sql, path_sql] ] - return [ ['JSON_GETPATH', base_sql, path_sql] ] + return [ ['JSON_GETPATH_STARRED', base_sql, path] ] + return [ ['JSON_GETPATH', base_sql, path] ] def nonzero(monad): return monad From 31833bda7faf7c6a7c476d2dd695c5f994c0832c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Jul 2016 15:29:21 +0300 Subject: [PATCH 042/547] Remove JSON_GETPATH_STARRED --- pony/orm/dbproviders/oracle.py | 9 +++++---- pony/orm/sqlbuilding.py | 2 -- pony/orm/sqltranslation.py | 3 --- 3 files changed, 5 insertions(+), 9 
deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 5f7206596..ddeaf5bab 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -13,7 +13,7 @@ from pony.orm import core, sqlbuilding, dbapiprovider, sqltranslation from pony.orm.core import log_orm, log_sql, DatabaseError, TranslationError from pony.orm.dbschema import DBSchema, DBObject, Table, Column -from pony.orm.ormtypes import Json +from pony.orm.ormtypes import Json, AnyItem from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple from pony.utils import throw from pony.converting import timedelta2str @@ -231,10 +231,11 @@ def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' def JSON_GETPATH(builder, expr, path): + for item in path: + if isinstance(item, AnyItem): + return 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' query = 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' - return 'REGEXP_REPLACE(', query, ", '(^\[|\]$)', '')" - def JSON_GETPATH_STARRED(builder, expr, path): - return 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' + return 'REGEXP_REPLACE(', query, ", '(^\\[|\\]$)', '')" def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 31d298201..3a4b7d5dc 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -519,8 +519,6 @@ def json_path(builder, path): return result def JSON_GETPATH(builder, expr, path): throw(NotImplementedError) - def JSON_GETPATH_STARRED(builder, expr, path): - return builder.JSON_GETPATH(expr, path) def JSON_CONCAT(builder, left, right): throw(NotImplementedError) def JSON_CONTAINS(builder, expr, path, 
key): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e469cd8d5..3ffca9cc6 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1726,9 +1726,6 @@ def get_path(monad): def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] - for item in path: - if isinstance(item, AnyItem): - return [ ['JSON_GETPATH_STARRED', base_sql, path] ] return [ ['JSON_GETPATH', base_sql, path] ] def nonzero(monad): From cc9cb4c0ca26297a262136e15b03c05bbc9204b0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 22 Jul 2016 17:25:41 +0300 Subject: [PATCH 043/547] Minor formatting changes --- pony/orm/sqltranslation.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 3ffca9cc6..7c471f7ba 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1571,18 +1571,14 @@ class JsonMixin(object): def mixin_init(monad): assert monad.type is Json, monad.type - def cast_from_json(monad, type): translator = monad.translator if issubclass(type, Json): return monad return translator.ExprMonad.new(translator, type, ['JSON_CAST', monad.getsql()[0], type ]) - def get_path(monad): return monad, [] - def __getitem__(monad, key): return monad.translator.JsonItemMonad(monad, key) - def contains(monad, key, not_in=False): if not isinstance(key, StringConstMonad): raise NotImplementedError translator = monad.translator @@ -1592,21 +1588,18 @@ def contains(monad, key, not_in=False): sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] if not_in: sql = [ 'NOT', sql ] return translator.BoolExprMonad(translator, sql) - def __or__(monad, other): translator = monad.translator if not isinstance(other, translator.JsonMixin): raise TypeError('Should be JSON: %s' % ast2src(other.node)) - left_sql, = monad.getsql() - right_sql, = other.getsql() - sql = ['JSON_CONCAT', left_sql, right_sql] + left_sql = 
monad.getsql()[0] + right_sql = other.getsql()[0] + sql = [ 'JSON_CONCAT', left_sql, right_sql ] return translator.JsonExprMonad(translator, Json, sql) - def len(monad): translator = monad.translator - sql, = monad.getsql() - return translator.NumericExprMonad( - translator, int, ['JSON_ARRAY_LENGTH', sql]) + sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] + return translator.NumericExprMonad(translator, int, sql) class JsonAttrMonad(JsonMixin, AttrMonad): pass @@ -1697,7 +1690,6 @@ class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass class JsonExprMonad(JsonMixin, ExprMonad): pass - class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): assert isinstance(parent, JsonMixin), parent @@ -1714,7 +1706,6 @@ def __init__(monad, parent, key): keyval = key.value else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) monad.key = keyval - def get_path(monad): path = [] while isinstance(monad, JsonItemMonad): @@ -1722,12 +1713,10 @@ def get_path(monad): monad = monad.parent path.reverse() return monad, path - def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] - return [ ['JSON_GETPATH', base_sql, path] ] - + return [ [ 'JSON_GETPATH', base_sql, path ] ] def nonzero(monad): return monad From 6c7cbdb656f68dcc41a1a2a6b9f2d2897855a7e6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 18 Jul 2016 14:10:54 +0300 Subject: [PATCH 044/547] Replace JSON_GETPATH and JSON_CAST to JSON_QUERY and JSON_VALUE --- pony/orm/dbproviders/mysql.py | 10 +- pony/orm/dbproviders/oracle.py | 11 +- pony/orm/dbproviders/postgres.py | 18 +-- pony/orm/dbproviders/sqlite.py | 145 ++++++++++-------- pony/orm/sqlbuilding.py | 14 +- pony/orm/sqltranslation.py | 17 +- pony/orm/tests/test_declarative_exceptions.py | 2 +- 7 files changed, 122 insertions(+), 95 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 
57b54b078..df8aa2ee5 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -88,8 +88,15 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return 'DATE_SUB(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' - def JSON_GETPATH(builder, expr, path): + def JSON_QUERY(builder, expr, path): return 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' + def JSON_VALUE(builder, expr, path, type): + result = 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' + if type is NoneType: + return 'NULLIF(', result, ", CAST('null' as JSON))" + if type in (bool, int): + return 'CAST(', result, ' AS SIGNED)' + return 'json_unquote(', result, ')' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' def EQ_JSON(builder, left, right): @@ -104,7 +111,6 @@ def JSON_CONTAINS(builder, expr, path, key): path_with_key_sql = builder.json_path(path + [ key[1] ]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result - type_mapping = {str: 'text', bool: 'boolean', int: 'signed', float: None, ormtypes.Json: 'json'} class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index ddeaf5bab..ba42b2005 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -230,12 +230,17 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' - def JSON_GETPATH(builder, expr, path): + def JSON_QUERY(builder, expr, path): for item in path: if isinstance(item, AnyItem): return 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' - query = 'JSON_QUERY(', 
builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' - return 'REGEXP_REPLACE(', query, ", '(^\\[|\\]$)', '')" + return 'REGEXP_REPLACE(JSON_QUERY(', \ + builder(expr), ', ', builder.json_path(path), " WITH WRAPPER), '(^\\[|\\]$)', '')" + json_value_type_mapping = {bool: 'NUMBER', int: 'NUMBER', float: 'NUMBER'} + def JSON_VALUE(builder, expr, path, type): + if type is Json: return builder.JSON_QUERY(expr, path) + type_name = builder.json_value_type_mapping.get(type, 'VARCHAR2') + return 'JSON_VALUE(', builder(expr), ', ', builder.json_path(path), ' RETURNING ', type_name, ')' def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 7c76d505c..24cc2a9cf 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -101,20 +101,20 @@ def json_path(builder, path): result.append(item if is_ident(item) else '"%s"' % item.replace('"', '\\"')) else: assert False, item return '{%s}' % ','.join(result) - def JSON_GETPATH(builder, expr, path): - return '(', builder(expr), "#>", builder.json_path(path), ')' + def JSON_QUERY(builder, expr, path): + return '(', builder(expr), " #> ", builder.json_path(path), ')' + json_value_type_mapping = {bool: 'boolean', int: 'integer', float: 'real'} + def JSON_VALUE(builder, expr, path, type): + if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) + type_name = builder.json_value_type_mapping.get(type, 'text') + sql = '(', builder(expr), " #>> ", builder.json_path(path), ')' + return sql if type_name == 'text' else (sql, '::', type_name) def JSON_CONCAT(builder, left, right): return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): - return (builder.JSON_GETPATH(expr, path) if path else builder(expr)), ' ? ', builder(key) + return (builder.JSON_QUERY(expr, path) if path else builder(expr)), ' ? 
', builder(key) def JSON_ARRAY_LENGTH(builder, value): return 'jsonb_array_length(', builder(value), ')' - def CAST(builder, expr, type): - return '(', builder(expr), ')::', builder.get_cast_type_name(type) - def JSON_CAST(builder, expr, type): - type = builder.get_cast_type_name(type) - if type == 'text': return '(', builder(expr), ')::', type - return '(', builder(expr), ')::text::', type class PGStrConverter(dbapiprovider.StrConverter): if PY2: diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 6a796052e..83e119b81 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -1,8 +1,7 @@ from __future__ import absolute_import from pony.py23compat import PY2, imap, basestring, buffer, int_types, unicode -import os.path -import json +import os.path, re, json import sqlite3 as sqlite from decimal import Decimal from datetime import datetime, date, time, timedelta @@ -134,35 +133,19 @@ def RANDOM(builder): PY_UPPER = make_unary_func('py_upper') PY_LOWER = make_unary_func('py_lower') - def JSON_PATH(builder, *items): - if builder.json1_available: - return SQLBuilder.JSON_PATH(builder, *items) - return builder.VALUE(json.dumps(items)) - def JSON_GETPATH(builder, expr, path): - if not builder.json1_available: - return 'py_json_extract(', builder(expr), ', ', builder.json_path(path), ', 0)' - return 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' - def JSON_GETPATH__QUOTE_STRINGS(builder, expr, path): - if not builder.json1_available: - return 'py_json_extract(', builder(expr), ', ', builder.json_path(path), ', 1)' - ret = 'json_extract(', builder(expr), ', null, ', builder.json_path(path), ')' - return 'unwrap_extract_json(', ret, ')' + def JSON_QUERY(builder, expr, path): + fname = 'json_extract' if builder.json1_available else 'py_json_extract' + return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', builder.json_path(path), '))' + # json_value_type_mapping = {unicode: 'text', bool: 
'boolean', int: 'integer', float: 'real', Json: None} + def JSON_VALUE(builder, expr, path, type): + fname = 'json_extract' if builder.json1_available else 'py_json_extract' + return fname, '(', builder(expr), ', ', builder.json_path(path), ')' def JSON_ARRAY_LENGTH(builder, value): if not builder.json1_available: raise SqliteExtensionUnavailable('json1') return 'json_array_length(', builder(value), ')' def JSON_CONTAINS(builder, expr, path, key): - # if builder.json1_available: - # TODO impl - with builder.json1_disabled(): - return 'py_json_contains(', builder(expr), ', ', builder.json_path(path), ', ', builder(key), ')' - - @contextmanager - def json1_disabled(builder): - was_available = builder.json1_available - builder.json1_available = False - yield - builder.json1_available = was_available + return 'py_json_contains(', builder(expr), ', ', builder.json_path(path), ', ', builder(key), ')' class SQLiteIntConverter(dbapiprovider.IntConverter): def sql_type(converter): @@ -414,50 +397,86 @@ def func(value): py_lower = make_string_function('py_lower', unicode.lower) @print_traceback -def unwrap_extract_json(value): +def py_json_unwrap(value): # [null,some-value] -> some-value - assert value.startswith('[null,') - result = value[6:-1] - if not result.startswith(('[', '{')): - result = json.loads(result) + assert value.startswith('[null,'), value + return value[6:-1] + +path_cache = {} + +json_path_re = re.compile(r'\[(\d+)\]|\.(?:(\w+)|"([^"]*)")', re.UNICODE) + +def _parse_path(path): + if path in path_cache: + return path_cache[path] + keys = None + if isinstance(path, basestring) and path.startswith('$'): + keys = [] + pos = 1 + path_len = len(path) + while pos < path_len: + match = json_path_re.match(path, pos) + if match is not None: + g1, g2, g3 = match.groups() + keys.append(int(g1) if g1 else g2 or g3) + pos = match.end() + else: + keys = None + break + else: keys = tuple(keys) + path_cache[path] = keys + return keys + +def _traverse(obj, keys): + if keys 
is None: return None + list_or_dict = (list, dict) + for key in keys: + if type(obj) not in list_or_dict: return None + try: obj = obj[key] + except (KeyError, IndexError): return None + return obj + +def _extract(expr, *paths): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + result = [] + for path in paths: + keys = _parse_path(path) + result.append(_traverse(expr, keys)) + return result[0] if len(paths) == 1 else result + +@print_traceback +def py_json_extract(expr, *paths): + result = _extract(expr, *paths) + if type(result) in (list, dict): + result = json.dumps(result, separators=(',', ':')) return result @print_traceback -def py_json_extract(value, path, quote_strings): - value = json.loads(value) - for item in json.loads(path): - try: - value = value[item] - except (KeyError, IndexError): - value = None - break - if isinstance(value, int) and not isinstance(value, bool): - return value - if isinstance(value, basestring) and not quote_strings: - return value - return json.dumps(value, separators=(',', ':')) +def py_json_query(expr, path, with_wrapper): + result = _extract(expr, path) + if type(result) not in (list, dict): + if not with_wrapper: return None + result = [result] + return json.dumps(result, separators=(',', ':')) @print_traceback -def py_json_contains(value, path, key): - value = json.loads(value) - try: - for item in json.loads(path): - value = value[item] - except (KeyError, IndexError): - value = None - if isinstance(value, (list, dict)): - return key in value +def py_json_value(expr, path): + result = _extract(expr, path) + return result if type(result) not in (list, dict) else None @print_traceback -def py_json_nonzero(value, path): - value = json.loads(value) - try: - for item in json.loads(path): - value = value[item] - except (KeyError, IndexError): - value = None - return bool(value) +def py_json_contains(expr, path, key): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + keys = _parse_path(path) 
+ expr = _traverse(expr, keys) + return type(expr) in (list, dict) and key in expr +@print_traceback +def py_json_nonzero(expr, path): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + keys = _parse_path(path) + expr = _traverse(expr, keys) + return bool(expr) class SQLitePool(Pool): def __init__(pool, filename, create_db): # called separately in each thread @@ -474,8 +493,8 @@ def _connect(pool): con.create_function('rand', 0, random) con.create_function('py_upper', 1, py_upper) con.create_function('py_lower', 1, py_lower) - con.create_function('unwrap_extract_json', 1, unwrap_extract_json) - con.create_function('py_json_extract', 3, py_json_extract) + con.create_function('py_json_unwrap', 1, py_json_unwrap) + con.create_function('py_json_extract', -1, py_json_extract) con.create_function('py_json_contains', 3, py_json_contains) con.create_function('py_json_nonzero', 2, py_json_nonzero) con.create_function('py_lower', 1, py_lower) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 3a4b7d5dc..db4f8cf7d 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -517,7 +517,9 @@ def json_path(builder, path): else: assert False result.append('\'') return result - def JSON_GETPATH(builder, expr, path): + def JSON_QUERY(builder, expr, path): + throw(NotImplementedError) + def JSON_VALUE(builder, expr, path, type): throw(NotImplementedError) def JSON_CONCAT(builder, left, right): throw(NotImplementedError) @@ -525,13 +527,3 @@ def JSON_CONTAINS(builder, expr, path, key): throw(NotImplementedError) def JSON_ARRAY_LENGTH(builder, value): throw(NotImplementedError) - def CAST(builder, expr, type): - type_name = builder.get_cast_type_name(type) - if type_name is None: return builder(expr) - return 'CAST(', builder(expr), ' AS ', type_name, ')' - JSON_CAST = CAST - def get_cast_type_name(builder, type): - if isinstance(type, basestring): return type - if type not in builder.typecast_mapping: throw(NotImplementedError, type) - 
return builder.typecast_mapping[type] - typecast_mapping = {unicode: 'text', bool: 'boolean', int: 'integer', float: 'real', Json: None} diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 7c471f7ba..ffdfbfdcd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1070,7 +1070,7 @@ def __or__(monad): throw(TypeError) def __and__(monad): throw(TypeError) def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) - def cast_from_json(monad, type): throw(TypeError) + def cast_from_json(monad, type): assert False, monad class RawSQLMonad(Monad): def __init__(monad, translator, rawtype, varkey): @@ -1571,10 +1571,6 @@ class JsonMixin(object): def mixin_init(monad): assert monad.type is Json, monad.type - def cast_from_json(monad, type): - translator = monad.translator - if issubclass(type, Json): return monad - return translator.ExprMonad.new(translator, type, ['JSON_CAST', monad.getsql()[0], type ]) def get_path(monad): return monad, [] def __getitem__(monad, key): @@ -1600,6 +1596,9 @@ def len(monad): translator = monad.translator sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] return translator.NumericExprMonad(translator, int, sql) + def cast_from_json(monad, type): + if type in (Json, NoneType): return monad + throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') class JsonAttrMonad(JsonMixin, AttrMonad): pass @@ -1713,10 +1712,16 @@ def get_path(monad): monad = monad.parent path.reverse() return monad, path + def cast_from_json(monad, type): + translator = monad.translator + if issubclass(type, Json): return monad + base_monad, path = monad.get_path() + sql = [ 'JSON_VALUE', base_monad.getsql()[0], path, type ] + return translator.ExprMonad.new(translator, type, sql) def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] - return [ [ 'JSON_GETPATH', base_sql, path ] ] + return [ [ 'JSON_QUERY', base_sql, path ] ] def 
nonzero(monad): return monad diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 42af7a8ef..4f5935cef 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -208,7 +208,7 @@ def test48(self): @raises_exception(TypeError, "'sum' is valid for numeric attributes only") def test49(self): sum(s.name for s in Student) - @raises_exception(NotImplementedError, "Parameter {'a':'b'} has unsupported type ") + @raises_exception(TypeError, "Cannot compare whole JSON value, you need to select specific sub-item: s.name == {'a':'b'}") def test50(self): # cannot compare JSON value to dynamic string, # because a database does not provide json.dumps(s.name) functionality From 7fc3fcc45dcacbc788ce0f3c6f87aceaf78a2e99 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 26 Jul 2016 12:54:07 +0300 Subject: [PATCH 045/547] JsonMixin.nonzero() & builder.JSON_NONZERO methods added --- pony/orm/dbproviders/mysql.py | 2 ++ pony/orm/dbproviders/oracle.py | 6 ++---- pony/orm/dbproviders/postgres.py | 18 +++--------------- pony/orm/dbproviders/sqlite.py | 2 ++ pony/orm/sqlbuilding.py | 2 ++ pony/orm/sqltranslation.py | 5 +++-- 6 files changed, 14 insertions(+), 21 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index df8aa2ee5..cd30ef41d 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -97,6 +97,8 @@ def JSON_VALUE(builder, expr, path, type): if type in (bool, int): return 'CAST(', result, ' AS SIGNED)' return 'json_unquote(', result, ')' + def JSON_NONZERO(builder, expr): + return 'COALESCE(CAST(', builder(expr), ''' as CHAR), 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' def EQ_JSON(builder, left, right): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 
ba42b2005..8bb10a029 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -125,10 +125,6 @@ def get_normalized_type_of(translator, value): if value == '': return NoneType return sqltranslation.SQLTranslator.get_normalized_type_of(value) - class JsonItemMonad(sqltranslation.JsonItemMonad): - def nonzero(monad): - raise NotImplementedError - class OraBuilder(sqlbuilding.SQLBuilder): dialect = 'Oracle' def INSERT(builder, table_name, columns, values, returning=None): @@ -241,6 +237,8 @@ def JSON_VALUE(builder, expr, path, type): if type is Json: return builder.JSON_QUERY(expr, path) type_name = builder.json_value_type_mapping.get(type, 'VARCHAR2') return 'JSON_VALUE(', builder(expr), ', ', builder.json_path(path), ' RETURNING ', type_name, ')' + def JSON_NONZERO(builder, expr): + return 'COALESCE(', builder(expr), ''', 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 24cc2a9cf..e581fa86e 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -39,21 +39,6 @@ class PGSchema(dbschema.DBSchema): class PGTranslator(SQLTranslator): dialect = 'PostgreSQL' - class JsonItemMonad(sqltranslation.JsonItemMonad): - def nonzero(monad): - translator = monad.translator - empty_str = translator.StringExprMonad( - translator, str, ['RAWSQL', '\'""\'::jsonb'] - ) - str_not_empty = translator.CmpMonad( - '!=', monad, empty_str - ) - is_true = monad.cast_from_json(bool).getsql()[0] - sql = ['AND'] - sql.extend(str_not_empty.getsql()) - sql.extend(is_true.getsql()) - return translator.BoolExprMonad(translator, sql) - class PGValue(Value): __slots__ = [] def __unicode__(self): @@ -109,6 +94,9 @@ def JSON_VALUE(builder, expr, path, type): type_name = builder.json_value_type_mapping.get(type, 'text') sql = 
'(', builder(expr), " #>> ", builder.json_path(path), ')' return sql if type_name == 'text' else (sql, '::', type_name) + def JSON_NONZERO(builder, expr): + return 'coalesce(', builder(expr), ", 'null'::jsonb) NOT IN (" \ + "'null'::jsonb, 'false'::jsonb, '0'::jsonb, '\"\"'::jsonb, '[]'::jsonb, '{}'::jsonb)" def JSON_CONCAT(builder, left, right): return '(', builder(left), '||', builder(right), ')' def JSON_CONTAINS(builder, expr, path, key): diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 83e119b81..2d83f1fab 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -140,6 +140,8 @@ def JSON_QUERY(builder, expr, path): def JSON_VALUE(builder, expr, path, type): fname = 'json_extract' if builder.json1_available else 'py_json_extract' return fname, '(', builder(expr), ', ', builder.json_path(path), ')' + def JSON_NONZERO(builder, expr): + return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): if not builder.json1_available: raise SqliteExtensionUnavailable('json1') diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index db4f8cf7d..faf3a1d6b 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -521,6 +521,8 @@ def JSON_QUERY(builder, expr, path): throw(NotImplementedError) def JSON_VALUE(builder, expr, path, type): throw(NotImplementedError) + def JSON_NONZERO(builder, expr): + throw(NotImplementedError) def JSON_CONCAT(builder, left, right): throw(NotImplementedError) def JSON_CONTAINS(builder, expr, path, key): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index ffdfbfdcd..2ecc5213a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1599,6 +1599,9 @@ def len(monad): def cast_from_json(monad, type): if type in (Json, NoneType): return monad throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') + def nonzero(monad): + 
translator = monad.translator + return translator.BoolExprMonad(translator, [ 'JSON_NONZERO', monad.getsql()[0] ]) class JsonAttrMonad(JsonMixin, AttrMonad): pass @@ -1722,8 +1725,6 @@ def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] return [ [ 'JSON_QUERY', base_sql, path ] ] - def nonzero(monad): - return monad class ConstMonad(Monad): @staticmethod From bd7627e5c1d23fe6672c126105890f8f34f1c877 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Aug 2016 10:25:31 +0300 Subject: [PATCH 046/547] To compare JSON fragments in SQLite & Oracle, keys should be sorted --- pony/orm/dbproviders/oracle.py | 1 + pony/orm/dbproviders/sqlite.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 8bb10a029..05a111bf3 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -337,6 +337,7 @@ def sql_type(converter): return 'RAW(16)' class OraJsonConverter(dbapiprovider.JsonConverter): + json_kwargs = {'separators': (',', ':'), 'sort_keys': True} optimistic = False # CLOBs cannot be compared with strings, and TO_CHAR(CLOB) returns first 4000 chars only def sql2py(converter, dbval): if hasattr(dbval, 'read'): dbval = dbval.read() diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 2d83f1fab..5f8b5bdc6 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -201,7 +201,7 @@ def py2sql(converter, val): return datetime2timestamp(val) class SQLiteJsonConverter(dbapiprovider.JsonConverter): - json_kwargs = {'separators': (',', ':')} + json_kwargs = {'separators': (',', ':'), 'sort_keys': True} def print_traceback(func): @wraps(func) @@ -450,7 +450,7 @@ def _extract(expr, *paths): def py_json_extract(expr, *paths): result = _extract(expr, *paths) if type(result) in (list, dict): - result = json.dumps(result, separators=(',', ':')) + result = json.dumps(result, 
**SQLiteJsonConverter.json_kwargs) return result @print_traceback @@ -459,7 +459,7 @@ def py_json_query(expr, path, with_wrapper): if type(result) not in (list, dict): if not with_wrapper: return None result = [result] - return json.dumps(result, separators=(',', ':')) + return json.dumps(result, **SQLiteJsonConverter.json_kwargs) @print_traceback def py_json_value(expr, path): From 1205721170c075f67466ec0b9667ee4af62fa37f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 18 Aug 2016 14:50:30 +0300 Subject: [PATCH 047/547] Turn off json.dumps `unsure_ascii` flag in SQLite & Oracle --- pony/orm/dbproviders/oracle.py | 2 +- pony/orm/dbproviders/sqlite.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 05a111bf3..6d24569f6 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -337,7 +337,7 @@ def sql_type(converter): return 'RAW(16)' class OraJsonConverter(dbapiprovider.JsonConverter): - json_kwargs = {'separators': (',', ':'), 'sort_keys': True} + json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} optimistic = False # CLOBs cannot be compared with strings, and TO_CHAR(CLOB) returns first 4000 chars only def sql2py(converter, dbval): if hasattr(dbval, 'read'): dbval = dbval.read() diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 5f8b5bdc6..19bed387a 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -201,7 +201,7 @@ def py2sql(converter, val): return datetime2timestamp(val) class SQLiteJsonConverter(dbapiprovider.JsonConverter): - json_kwargs = {'separators': (',', ':'), 'sort_keys': True} + json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} def print_traceback(func): @wraps(func) From 8f95b70695b5b9e1205987642c570f1ee3fea40a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 17 Aug 2016 14:57:56 
+0300 Subject: [PATCH 048/547] Add `py_json_array_length` function to SQLite --- pony/orm/dbproviders/sqlite.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 19bed387a..5e3ea4fe0 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -143,9 +143,8 @@ def JSON_VALUE(builder, expr, path, type): def JSON_NONZERO(builder, expr): return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): - if not builder.json1_available: - raise SqliteExtensionUnavailable('json1') - return 'json_array_length(', builder(value), ')' + func_name = 'json_array_length' if builder.json1_available else 'py_json_array_length' + return func_name, '(', builder(value), ')' def JSON_CONTAINS(builder, expr, path, key): return 'py_json_contains(', builder(expr), ', ', builder.json_path(path), ', ', builder(key), ')' @@ -480,6 +479,14 @@ def py_json_nonzero(expr, path): expr = _traverse(expr, keys) return bool(expr) +@print_traceback +def py_json_array_length(expr, path=None): + expr = json.loads(expr) if isinstance(expr, basestring) else expr + if path: + keys = _parse_path(path) + expr = _traverse(expr, keys) + return len(expr) if type(expr) is list else 0 + class SQLitePool(Pool): def __init__(pool, filename, create_db): # called separately in each thread pool.filename = filename @@ -499,6 +506,7 @@ def _connect(pool): con.create_function('py_json_extract', -1, py_json_extract) con.create_function('py_json_contains', 3, py_json_contains) con.create_function('py_json_nonzero', 2, py_json_nonzero) + con.create_function('py_json_array_length', -1, py_json_array_length) con.create_function('py_lower', 1, py_lower) if sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') From c8ff02609edbd194d0e9d9fdd5df0e22ac0bc597 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 27 Jul 
2016 13:54:41 +0300 Subject: [PATCH 049/547] Refactoring: use None instead of NoneMonad in slices --- pony/orm/sqltranslation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 2ecc5213a..a9269bf6a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -775,8 +775,11 @@ def postSubscript(translator, node): if isinstance(sub, ast.Sliceobj): start, stop, step = (sub.nodes+[None])[:3] if start is not None: start = start.monad + if isinstance(start, NoneMonad): start = None if stop is not None: stop = stop.monad + if isinstance(stop, NoneMonad): stop = None if step is not None: step = step.monad + if isinstance(step, NoneMonad): step = None return node.expr.monad[start:stop:step] else: return node.expr.monad[sub.monad] def postSlice(translator, node): @@ -784,8 +787,10 @@ def postSlice(translator, node): expr_monad = node.expr.monad upper = node.upper if upper is not None: upper = upper.monad + if isinstance(upper, NoneMonad): upper = None lower = node.lower if lower is not None: lower = lower.monad + if isinstance(lower, NoneMonad): lower = None return expr_monad[lower:upper] def postSliceobj(translator, node): pass @@ -1338,8 +1343,6 @@ def __getitem__(monad, index): elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') start, stop = index.start, index.stop - if isinstance(start, NoneMonad): start = None - if isinstance(stop, NoneMonad): stop = None if start is None and stop is None: return monad if isinstance(monad, translator.StringConstMonad) \ and (start is None or isinstance(start, translator.NumericConstMonad)) \ @@ -1698,9 +1701,7 @@ def __init__(monad, parent, key): Monad.__init__(monad, parent.translator, Json) monad.parent = parent if isinstance(key, slice): - for item in (key.start, key.stop, key.step): - if not isinstance(item, (NoneType, NoneMonad)): - throw(NotImplementedError) + if key 
!= slice(None, None, None): throw(NotImplementedError) keyval = AnyNum elif isinstance(key, EllipsisMonad): keyval = AnyStr From 43d7538324109d5276a731f427de1a8f8b20a214 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 26 Jul 2016 22:19:18 +0300 Subject: [PATCH 050/547] Refactoring: param.eval(values) method instead of convert(values, params) function --- pony/orm/sqlbuilding.py | 45 ++++++++++++++++++----------------------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index faf3a1d6b..ba2f47738 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -19,12 +19,23 @@ def __init__(param, paramstyle, id, paramkey, converter=None): param.id = id param.paramkey = paramkey param.converter = converter - def py2sql(param, val): - converter = param.converter - if converter is not None: - val = converter.val2dbval(val) - val = converter.py2sql(val) - return val + def eval(param, values): + varkey, i, j = param.paramkey + value = values[varkey] + t = type(value) + if i is not None: + if t is tuple: value = value[i] + elif t is RawSQL: value = value.values[i] + else: assert False + if j is not None: + assert type(type(value)).__name__ == 'EntityMeta' + value = value._get_raw_pkval_()[j] + if value is not None: # can value be None at all? + converter = param.converter + if converter is not None: + value = converter.val2dbval(value) + value = converter.py2sql(value) + return value def __unicode__(param): paramstyle = param.style if paramstyle == 'qmark': return u'?' 
@@ -133,22 +144,6 @@ def new_method(builder, *args, **kwargs): new_method.__name__ = method.__name__ return new_method -def convert(values, params): - for param in params: - varkey, i, j = param.paramkey - value = values[varkey] - t = type(value) - if i is not None: - if t is tuple: value = value[i] - elif t is RawSQL: value = value.values[i] - else: assert False - if j is not None: - assert type(type(value)).__name__ == 'EntityMeta' - value = value._get_raw_pkval_()[j] - if value is not None: # can value be None at all? - value = param.py2sql(value) - yield value - class SQLBuilder(object): dialect = None make_param = Param @@ -168,15 +163,15 @@ def __init__(builder, provider, ast): if paramstyle in ('qmark', 'format'): params = tuple(x for x in builder.result if isinstance(x, Param)) def adapter(values): - return tuple(convert(values, params)) + return tuple((param.eval(values) for param in params)) elif paramstyle == 'numeric': params = tuple(param for param in sorted(itervalues(builder.keys), key=attrgetter('id'))) def adapter(values): - return tuple(convert(values, params)) + return tuple(param.eval(values) for param in params) elif paramstyle in ('named', 'pyformat'): params = tuple(param for param in sorted(itervalues(builder.keys), key=attrgetter('id'))) def adapter(values): - return dict(('p%d' % param.id, value) for param, value in izip(params, convert(values, params))) + return {'p%d' % param.id: param.eval(values) for param in params} else: throw(NotImplementedError, paramstyle) builder.params = params builder.layout = tuple(param.paramkey for param in params) From bc12bb1179187d24e249449b9b32e4418cca194b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 27 Jul 2016 14:50:40 +0300 Subject: [PATCH 051/547] builder.make_value -> builder.value_class, builder.make_param -> builder.param_class --- pony/orm/dbproviders/postgres.py | 2 +- pony/orm/sqlbuilding.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git 
a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index e581fa86e..70d5aec2d 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -49,7 +49,7 @@ def __unicode__(self): class PGSQLBuilder(SQLBuilder): dialect = 'PostgreSQL' - make_value = PGValue + value_class = PGValue def INSERT(builder, table_name, columns, values, returning=None): if not values: result = [ 'INSERT INTO ', builder.quote_name(table_name) ,' DEFAULT VALUES' ] else: result = SQLBuilder.INSERT(builder, table_name, columns, values) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index ba2f47738..3f171efae 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -146,8 +146,8 @@ def new_method(builder, *args, **kwargs): class SQLBuilder(object): dialect = None - make_param = Param - make_value = Value + param_class = Param + value_class = Value indent_spaces = " " * 4 def __init__(builder, provider, ast): builder.provider = provider @@ -351,13 +351,13 @@ def PARAM(builder, paramkey, converter=None): keys = builder.keys param = keys.get(paramkey) if param is None: - param = Param(builder.paramstyle, len(keys) + 1, paramkey, converter) + param = builder.param_class(builder.paramstyle, len(keys) + 1, paramkey, converter) keys[paramkey] = param return [ param ] def ROW(builder, *items): return '(', join(', ', imap(builder, items)), ')' def VALUE(builder, value): - return [ builder.make_value(builder.paramstyle, value) ] + return [builder.value_class(builder.paramstyle, value)] def AND(builder, *cond_list): cond_list = [ builder(condition) for condition in cond_list ] return join(' AND ', cond_list) From 39d84993d60551633d59469d47bda3463a5f31f3 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 27 Jul 2016 19:13:24 +0300 Subject: [PATCH 052/547] Composite params for JSON path expressions --- pony/orm/dbproviders/mysql.py | 10 ++- pony/orm/dbproviders/oracle.py | 16 ++-- pony/orm/dbproviders/postgres.py | 20 +++-- 
pony/orm/dbproviders/sqlite.py | 11 ++- pony/orm/ormtypes.py | 9 +- pony/orm/sqlbuilding.py | 84 ++++++++++++++----- pony/orm/sqltranslation.py | 13 ++- .../tests/test_sqlbuilding_formatstyles.py | 14 ++-- 8 files changed, 107 insertions(+), 70 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index cd30ef41d..02d0e7164 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -89,9 +89,11 @@ def DATETIME_SUB(builder, expr, delta): return 'DATE_SUB(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' def JSON_QUERY(builder, expr, path): - return 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' + path_sql, has_params, has_wildcards = builder.json_path(path) + return 'json_extract(', builder(expr), ', ', path_sql, ')' def JSON_VALUE(builder, expr, path, type): - result = 'json_extract(', builder(expr), ', ', builder.json_path(path), ')' + path_sql, has_params, has_wildcards = builder.json_path(path) + result = 'json_extract(', builder(expr), ', ', path_sql, ')' if type is NoneType: return 'NULLIF(', result, ", CAST('null' as JSON))" if type in (bool, int): @@ -109,8 +111,8 @@ def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) result = [ '(json_contains(', expr_sql, ', ', builder([ 'VALUE', json.dumps([ key[1] ]) ]) ] - path_sql = builder.json_path(path) - path_with_key_sql = builder.json_path(path + [ key[1] ]) + path_sql, has_params, has_wildcards = builder.json_path(path) + path_with_key_sql, _, _ = builder.json_path(path + [ key[1] ]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 6d24569f6..60788bd13 100644 --- a/pony/orm/dbproviders/oracle.py +++ 
b/pony/orm/dbproviders/oracle.py @@ -13,7 +13,7 @@ from pony.orm import core, sqlbuilding, dbapiprovider, sqltranslation from pony.orm.core import log_orm, log_sql, DatabaseError, TranslationError from pony.orm.dbschema import DBSchema, DBObject, Table, Column -from pony.orm.ormtypes import Json, AnyItem +from pony.orm.ormtypes import Json from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple from pony.utils import throw from pony.converting import timedelta2str @@ -227,22 +227,22 @@ def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' def JSON_QUERY(builder, expr, path): - for item in path: - if isinstance(item, AnyItem): - return 'JSON_QUERY(', builder(expr), ', ', builder.json_path(path), ' WITH WRAPPER)' - return 'REGEXP_REPLACE(JSON_QUERY(', \ - builder(expr), ', ', builder.json_path(path), " WITH WRAPPER), '(^\\[|\\]$)', '')" + expr_sql = builder(expr) + path_sql, has_params, has_wildcards = builder.json_path(path) + if has_wildcards: return 'JSON_QUERY(', expr_sql, ', ', path_sql, ' WITH WRAPPER)' + return 'REGEXP_REPLACE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '(^\\[|\\]$)', '')" json_value_type_mapping = {bool: 'NUMBER', int: 'NUMBER', float: 'NUMBER'} def JSON_VALUE(builder, expr, path, type): if type is Json: return builder.JSON_QUERY(expr, path) + path_sql, has_params, has_wildcards = builder.json_path(path) type_name = builder.json_value_type_mapping.get(type, 'VARCHAR2') - return 'JSON_VALUE(', builder(expr), ', ', builder.json_path(path), ' RETURNING ', type_name, ')' + return 'JSON_VALUE(', builder(expr), ', ', path_sql, ' RETURNING ', type_name, ')' def JSON_NONZERO(builder, expr): return 'COALESCE(', builder(expr), ''', 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], 
basestring) expr_sql = builder(expr) - path_sql = builder.json_path(path + [ key[1] ]) + path_sql, has_params, has_wildcards = builder.json_path(path + [ key[1] ]) return 'JSON_EXISTS(', expr_sql, ', ', path_sql, ')' class OraBoolConverter(dbapiprovider.BoolConverter): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 70d5aec2d..8054d8e1f 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -77,22 +77,24 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' - def json_path(builder, path): + def eval_json_path(builder, values): result = [] - for item in path: - if isinstance(item, int): - result.append(str(item)) - elif isinstance(item, basestring): - result.append(item if is_ident(item) else '"%s"' % item.replace('"', '\\"')) - else: assert False, item + for value in values: + if isinstance(value, int): + result.append(str(value)) + elif isinstance(value, basestring): + result.append(value if is_ident(value) else '"%s"' % value.replace('"', '\\"')) + else: assert False, value return '{%s}' % ','.join(result) def JSON_QUERY(builder, expr, path): - return '(', builder(expr), " #> ", builder.json_path(path), ')' + path_sql, has_params, has_wildcards = builder.json_path(path) + return '(', builder(expr), " #> ", path_sql, ')' json_value_type_mapping = {bool: 'boolean', int: 'integer', float: 'real'} def JSON_VALUE(builder, expr, path, type): if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) + path_sql, has_params, has_wildcards = builder.json_path(path) + sql = '(', builder(expr), " #>> ", path_sql, ')' type_name = builder.json_value_type_mapping.get(type, 'text') - sql = '(', builder(expr), " #>> ", builder.json_path(path), ')' return sql if type_name == 'text' else (sql, '::', type_name) def JSON_NONZERO(builder, expr): 
return 'coalesce(', builder(expr), ", 'null'::jsonb) NOT IN (" \ diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 5e3ea4fe0..4bd2cc4d9 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -135,18 +135,21 @@ def RANDOM(builder): def JSON_QUERY(builder, expr, path): fname = 'json_extract' if builder.json1_available else 'py_json_extract' - return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', builder.json_path(path), '))' + path_sql, has_params, has_wildcards = builder.json_path(path) + return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', path_sql, '))' # json_value_type_mapping = {unicode: 'text', bool: 'boolean', int: 'integer', float: 'real', Json: None} def JSON_VALUE(builder, expr, path, type): - fname = 'json_extract' if builder.json1_available else 'py_json_extract' - return fname, '(', builder(expr), ', ', builder.json_path(path), ')' + func_name = 'json_extract' if builder.json1_available else 'py_json_extract' + path_sql, has_params, has_wildcards = builder.json_path(path) + return func_name, '(', builder(expr), ', ', path_sql, ')' def JSON_NONZERO(builder, expr): return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): func_name = 'json_array_length' if builder.json1_available else 'py_json_array_length' return func_name, '(', builder(value), ')' def JSON_CONTAINS(builder, expr, path, key): - return 'py_json_contains(', builder(expr), ', ', builder.json_path(path), ', ', builder(key), ')' + path_sql, has_params, has_wildcards = builder.json_path(path) + return 'py_json_contains(', builder(expr), ', ', path_sql, ', ', builder(key), ')' class SQLiteIntConverter(dbapiprovider.IntConverter): def sql_type(converter): diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 9f07a9462..e0c475772 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -153,7 +153,7 @@ def normalize_type(t): if t is 
NoneType: return t t = type_normalization_dict.get(t, t) if t in primitive_types: return t - if issubclass(t, (AnyItem, slice, type(Ellipsis))): return t + if t in (slice, type(Ellipsis)): return t if issubclass(t, basestring): return unicode if issubclass(t, (dict, Json)): return Json throw(TypeError, 'Unsupported type %r' % t.__name__) @@ -294,10 +294,3 @@ def __init__(self, wrapped): def __repr__(self): return '' % self.wrapped - -class AnyItem(object): - def __init__(self, type): - self.type = type - -AnyStr = AnyItem('Str') -AnyNum = AnyItem('Number') diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 3f171efae..a22d10739 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -8,15 +8,15 @@ from pony import options from pony.utils import datetime2timestamp, throw, is_ident -from pony.orm.ormtypes import RawSQL, Json, AnyNum, AnyStr +from pony.orm.ormtypes import RawSQL, Json class AstError(Exception): pass class Param(object): __slots__ = 'style', 'id', 'paramkey', 'converter' - def __init__(param, paramstyle, id, paramkey, converter=None): + def __init__(param, paramstyle, paramkey, converter=None): param.style = paramstyle - param.id = id + param.id = None param.paramkey = paramkey param.converter = converter def eval(param, values): @@ -48,6 +48,17 @@ def __unicode__(param): def __repr__(param): return '%s(%r)' % (param.__class__.__name__, param.paramkey) +class CompositeParam(Param): + __slots__ = 'items', 'func' + def __init__(param, paramstyle, paramkey, items, func): + for item in items: assert isinstance(item, (Param, Value)), item + Param.__init__(param, paramstyle, paramkey) + param.items = items + param.func = func + def eval(param, values): + args = [ item.eval(values) if isinstance(item, Param) else item.value for item in param.items ] + return param.func(args) + class Value(object): __slots__ = 'paramstyle', 'value' def __init__(self, paramstyle, value): @@ -147,6 +158,7 @@ def new_method(builder, *args, 
**kwargs): class SQLBuilder(object): dialect = None param_class = Param + composite_param_class = CompositeParam value_class = Value indent_spaces = " " * 4 def __init__(builder, provider, ast): @@ -159,22 +171,24 @@ def __init__(builder, provider, ast): builder.inner_join_syntax = options.INNER_JOIN_SYNTAX builder.suppress_aliases = False builder.result = flat(builder(ast)) + params = tuple(x for x in builder.result if isinstance(x, Param)) + layout = [] + for i, param in enumerate(params): + if param.id is None: param.id = i + 1 + layout.append(param.paramkey) + builder.layout = layout builder.sql = u''.join(imap(unicode, builder.result)).rstrip('\n') if paramstyle in ('qmark', 'format'): - params = tuple(x for x in builder.result if isinstance(x, Param)) def adapter(values): - return tuple((param.eval(values) for param in params)) + return tuple(param.eval(values) for param in params) elif paramstyle == 'numeric': - params = tuple(param for param in sorted(itervalues(builder.keys), key=attrgetter('id'))) def adapter(values): return tuple(param.eval(values) for param in params) elif paramstyle in ('named', 'pyformat'): - params = tuple(param for param in sorted(itervalues(builder.keys), key=attrgetter('id'))) def adapter(values): return {'p%d' % param.id: param.eval(values) for param in params} else: throw(NotImplementedError, paramstyle) builder.params = params - builder.layout = tuple(param.paramkey for param in params) builder.adapter = adapter def __call__(builder, ast): if isinstance(ast, basestring): @@ -348,12 +362,16 @@ def COLUMN(builder, table_alias, col_name): return [ '%s' % builder.quote_name(col_name) ] return [ '%s.%s' % (builder.quote_name(table_alias), builder.quote_name(col_name)) ] def PARAM(builder, paramkey, converter=None): + return builder.make_param(builder.param_class, paramkey, converter) + def make_param(builder, param_class, paramkey, *args): keys = builder.keys param = keys.get(paramkey) if param is None: - param = 
builder.param_class(builder.paramstyle, len(keys) + 1, paramkey, converter) + param = param_class(builder.paramstyle, paramkey, *args) keys[paramkey] = param return [ param ] + def make_composite_param(builder, paramkey, items, func): + return builder.make_param(builder.composite_param_class, paramkey, items, func) def ROW(builder, *items): return '(', join(', ', imap(builder, items)), ')' def VALUE(builder, value): @@ -499,19 +517,41 @@ def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] def json_path(builder, path): - result = ['\'$'] - for item in path: - if isinstance(item, int): - result.append('[%d]' % item) - elif item is AnyNum: - result.append('[*]') - elif isinstance(item, str): - result.append('.' + item if is_ident(item) else '."%s"' % item.replace('"', '\\"')) - elif item is AnyStr: - result.append('.*') - else: assert False - result.append('\'') - return result + items = [] + for element in path: items.extend(builder(element)) + empty_slice = slice(None, None, None) + has_params = False + has_wildcards = False + for item in items: + if isinstance(item, Param): + has_params = True + elif isinstance(item, Value): + value = item.value + if value is Ellipsis or value == empty_slice: has_wildcards = True + else: assert isinstance(value, (int, basestring)), value + else: assert False, item + if has_params: + paramkey = tuple(item.paramkey if isinstance(item, Param) else + None if type(item.value) is slice else item.value + for item in items) + path_sql = builder.make_composite_param(paramkey, items, builder.eval_json_path) + else: + result_value = builder.eval_json_path(item.value for item in items) + path_sql = builder.value_class(builder.paramstyle, result_value) + return path_sql, has_params, has_wildcards + @classmethod + def eval_json_path(cls, values): + result = ['$'] + append = result.append + empty_slice = slice(None, None, None) + for value in values: + if 
isinstance(value, int): append('[%d]' % value) + elif isinstance(value, str): + append('.' + value if is_ident(value) else '."%s"' % value.replace('"', '\\"')) + elif value is Ellipsis: append('.*') + elif value == empty_slice: append('[*]') + else: assert False, value + return ''.join(result) def JSON_QUERY(builder, expr, path): throw(NotImplementedError) def JSON_VALUE(builder, expr, path, type): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a9269bf6a..845d80f6a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -17,7 +17,7 @@ from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ get_normalized_type_of, normalize_type, coerce_types, are_comparable_types, \ - Json, AnyStr, AnyNum, AnyItem + Json from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper @@ -1702,17 +1702,14 @@ def __init__(monad, parent, key): monad.parent = parent if isinstance(key, slice): if key != slice(None, None, None): throw(NotImplementedError) - keyval = AnyNum - elif isinstance(key, EllipsisMonad): - keyval = AnyStr - elif isinstance(key, (StringConstMonad, NumericConstMonad)): - keyval = key.value + monad.key_ast = [ 'VALUE', key ] + elif isinstance(key, (ParamMonad, StringConstMonad, NumericConstMonad, EllipsisMonad)): + monad.key_ast = key.getsql()[0] else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) - monad.key = keyval def get_path(monad): path = [] while isinstance(monad, JsonItemMonad): - path.append(monad.key) + path.append(monad.key_ast) monad = monad.parent path.reverse() return monad, path diff --git a/pony/orm/tests/test_sqlbuilding_formatstyles.py b/pony/orm/tests/test_sqlbuilding_formatstyles.py index 2a975959f..40804ffea 100644 --- a/pony/orm/tests/test_sqlbuilding_formatstyles.py +++ b/pony/orm/tests/test_sqlbuilding_formatstyles.py @@ -8,8 +8,8 @@ class 
TestFormatStyles(unittest.TestCase): def setUp(self): - self.key1 = object() - self.key2 = object() + self.key1 = 'KEY1' + self.key2 = 'KEY2' self.provider = DBAPIProvider(pony_pool_mockup=TestPool(None)) self.ast = [ SELECT, [ ALL, [COLUMN, None, 'A']], [ FROM, [None, TABLE, 'T1']], [ WHERE, [ EQ, [COLUMN, None, 'B'], [ PARAM, self.key1 ] ], @@ -24,35 +24,35 @@ def test_qmark(self): self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = ?\n AND "C" = ?\n AND "D" = ?\n AND "E" = ?') - self.assertEqual(b.layout, (self.key1, self.key2, self.key2, self.key1)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_numeric(self): self.provider.paramstyle = 'numeric' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = :1\n AND "C" = :2\n AND "D" = :2\n AND "E" = :1') - self.assertEqual(b.layout, (self.key1, self.key2)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_named(self): self.provider.paramstyle = 'named' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = :p1\n AND "C" = :p2\n AND "D" = :p2\n AND "E" = :p1') - self.assertEqual(b.layout, (self.key1, self.key2)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_format(self): self.provider.paramstyle = 'format' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = %s\n AND "C" = %s\n AND "D" = %s\n AND "E" = %s') - self.assertEqual(b.layout, (self.key1, self.key2, self.key2, self.key1)) + self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) def test_pyformat(self): self.provider.paramstyle = 'pyformat' b = SQLBuilder(self.provider, self.ast) self.assertEqual(b.sql, 'SELECT "A"\n' 'FROM "T1"\n' 'WHERE "B" = %(p1)s\n AND "C" = %(p2)s\n AND "D" = %(p2)s\n AND "E" = %(p1)s') - self.assertEqual(b.layout, (self.key1, self.key2)) + 
self.assertEqual(b.layout, [self.key1, self.key2, self.key2, self.key1]) if __name__ == "__main__": From 1d767b139cfb7dad7cad33f589d5d21ac23cb25b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 17 Aug 2016 20:39:53 +0300 Subject: [PATCH 053/547] Add `param.optimistic` flag & `converter.dbvals_equal` method --- pony/orm/core.py | 35 +++++++++++++++-------------------- pony/orm/dbapiprovider.py | 8 +++++++- pony/orm/sqlbuilding.py | 11 ++++++----- 3 files changed, 28 insertions(+), 26 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 5d0ae7d3c..5cc9f1b2a 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2046,11 +2046,10 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): assert attr.pk_offset is None if new_dbval is NOT_LOADED: assert is_reverse_call old_dbval = obj._dbvals_.get(attr, NOT_LOADED) - - if attr.py_type is float: - if old_dbval is NOT_LOADED: pass - elif attr.converters[0].equals(old_dbval, new_dbval): return - elif old_dbval == new_dbval: return + if old_dbval is not NOT_LOADED: + if old_dbval == new_dbval or ( + not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): + return bit = obj._bits_[attr] if obj._rbits_ & bit: @@ -4085,12 +4084,14 @@ def _get_attrs_(entity, only=None, exclude=None, with_collections=False, with_la entity._attrnames_cache_[key] = attrs return attrs -def populate_criteria_list(criteria_list, columns, converters, operations, params_count=0, table_alias=None): +def populate_criteria_list(criteria_list, columns, converters, operations, + params_count=0, table_alias=None, optimistic=False): for column, op, converter in izip(columns, operations, converters): if op == 'IS_NULL': criteria_list.append([ op, [ 'COLUMN', None, column ] ]) else: - criteria_list.append([ op, [ 'COLUMN', table_alias, column ], [ 'PARAM', (params_count, None, None), converter ] ]) + criteria_list.append([ op, [ 'COLUMN', table_alias, column ], + [ 'PARAM', (params_count, None, None), 
converter, optimistic ] ]) params_count += 1 return params_count @@ -4347,17 +4348,11 @@ def _db_set_(obj, avdict, unpickling=False): assert attr.pk_offset is None assert new_dbval is not NOT_LOADED old_dbval = get_dbval(attr, NOT_LOADED) - if unpickling and old_dbval is not NOT_LOADED: - del avdict[attr] - continue - elif attr.py_type is float: - if old_dbval is NOT_LOADED: pass - elif attr.converters[0].equals(old_dbval, new_dbval): + if old_dbval is not NOT_LOADED: + if unpickling or old_dbval == new_dbval or ( + not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): del avdict[attr] continue - elif old_dbval == new_dbval: - del avdict[attr] - continue bit = obj._bits_[attr] if rbits & bit: throw(UnrepeatableReadError, @@ -4727,8 +4722,8 @@ def _save_updated_(obj): pk_columns = obj._pk_columns_ pk_converters = obj._pk_converters_ params_count = populate_criteria_list(where_list, pk_columns, pk_converters, repeat('EQ'), params_count) - if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count) + if optimistic_columns: populate_criteria_list( + where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count, optimistic=True) sql_ast = [ 'UPDATE', obj._table_, list(izip(update_columns, update_params)), where_list ] sql, adapter = database._ast2sql(sql_ast) obj._update_sql_cache_[query_key] = sql, adapter @@ -4756,8 +4751,8 @@ def _save_deleted_(obj): if cached_sql is None: where_list = [ 'WHERE' ] params_count = populate_criteria_list(where_list, obj._pk_columns_, obj._pk_converters_, repeat('EQ')) - if optimistic_columns: - populate_criteria_list(where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count) + if optimistic_columns: populate_criteria_list( + where_list, optimistic_columns, optimistic_converters, optimistic_ops, params_count, optimistic=True) from_ast = [ 'FROM', [ None, 'TABLE', obj._table_ ] ] sql_ast = [ 
'DELETE', None, from_ast, where_list ] sql, adapter = database._ast2sql(sql_ast) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index d52323d0b..69378b227 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -352,6 +352,8 @@ def val2dbval(self, val, obj=None): return val def dbval2val(self, dbval, obj=None): return dbval + def dbvals_equal(self, x, y): + return x == y def get_sql_type(converter, attr=None): if attr is not None and attr.sql_type is not None: return attr.sql_type @@ -534,7 +536,7 @@ def validate(converter, val): throw(ValueError, 'Value %r of attr %s is greater than the maximum allowed value %r' % (val, converter.attr, converter.max_val)) return val - def equals(converter, x, y): + def dbvals_equal(converter, x, y): tolerance = converter.tolerance if tolerance is None or x is None or y is None: return x == y denominator = max(abs(x), abs(y)) @@ -749,5 +751,9 @@ def dbval2val(self, dbval, obj=None): if obj is None: return val return TrackedValue.make(obj, self.attr, val) + def dbvals_equal(self, x, y): + if isinstance(x, basestring): x = json.loads(x) + if isinstance(y, basestring): y = json.loads(y) + return x == y def sql_type(self): return "JSON" diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index a22d10739..57073ef5b 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -13,12 +13,13 @@ class AstError(Exception): pass class Param(object): - __slots__ = 'style', 'id', 'paramkey', 'converter' - def __init__(param, paramstyle, paramkey, converter=None): + __slots__ = 'style', 'id', 'paramkey', 'converter', 'optimistic' + def __init__(param, paramstyle, paramkey, converter=None, optimistic=False): param.style = paramstyle param.id = None param.paramkey = paramkey param.converter = converter + param.optimistic = optimistic def eval(param, values): varkey, i, j = param.paramkey value = values[varkey] @@ -33,7 +34,7 @@ def eval(param, values): if value is not None: # can value be 
None at all? converter = param.converter if converter is not None: - value = converter.val2dbval(value) + if not param.optimistic: value = converter.val2dbval(value) value = converter.py2sql(value) return value def __unicode__(param): @@ -361,8 +362,8 @@ def COLUMN(builder, table_alias, col_name): if builder.suppress_aliases or not table_alias: return [ '%s' % builder.quote_name(col_name) ] return [ '%s.%s' % (builder.quote_name(table_alias), builder.quote_name(col_name)) ] - def PARAM(builder, paramkey, converter=None): - return builder.make_param(builder.param_class, paramkey, converter) + def PARAM(builder, paramkey, converter=None, optimistic=False): + return builder.make_param(builder.param_class, paramkey, converter, optimistic) def make_param(builder, param_class, paramkey, *args): keys = builder.keys param = keys.get(paramkey) From ca40693aefb9fd00824313974e3a254a1bf0ec82 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 27 Jul 2016 19:30:41 +0300 Subject: [PATCH 054/547] Rename: builder.json_path() -> builder.build_json_path() --- pony/orm/dbproviders/mysql.py | 8 ++++---- pony/orm/dbproviders/oracle.py | 6 +++--- pony/orm/dbproviders/postgres.py | 4 ++-- pony/orm/dbproviders/sqlite.py | 6 +++--- pony/orm/sqlbuilding.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 02d0e7164..bda8549ac 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -89,10 +89,10 @@ def DATETIME_SUB(builder, expr, delta): return 'DATE_SUB(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' def JSON_QUERY(builder, expr, path): - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'json_extract(', builder(expr), ', ', path_sql, ')' def JSON_VALUE(builder, expr, path, type): - path_sql, 
has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) result = 'json_extract(', builder(expr), ', ', path_sql, ')' if type is NoneType: return 'NULLIF(', result, ", CAST('null' as JSON))" @@ -111,8 +111,8 @@ def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) result = [ '(json_contains(', expr_sql, ', ', builder([ 'VALUE', json.dumps([ key[1] ]) ]) ] - path_sql, has_params, has_wildcards = builder.json_path(path) - path_with_key_sql, _, _ = builder.json_path(path + [ key[1] ]) + path_sql, has_params, has_wildcards = builder.build_json_path(path) + path_with_key_sql, _, _ = builder.build_json_path(path + [key[1]]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 60788bd13..02b7e58ea 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -228,13 +228,13 @@ def DATETIME_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' def JSON_QUERY(builder, expr, path): expr_sql = builder(expr) - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) if has_wildcards: return 'JSON_QUERY(', expr_sql, ', ', path_sql, ' WITH WRAPPER)' return 'REGEXP_REPLACE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '(^\\[|\\]$)', '')" json_value_type_mapping = {bool: 'NUMBER', int: 'NUMBER', float: 'NUMBER'} def JSON_VALUE(builder, expr, path, type): if type is Json: return builder.JSON_QUERY(expr, path) - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) type_name = builder.json_value_type_mapping.get(type, 'VARCHAR2') return 'JSON_VALUE(', builder(expr), ', ', path_sql, ' 
RETURNING ', type_name, ')' def JSON_NONZERO(builder, expr): @@ -242,7 +242,7 @@ def JSON_NONZERO(builder, expr): def JSON_CONTAINS(builder, expr, path, key): assert key[0] == 'VALUE' and isinstance(key[1], basestring) expr_sql = builder(expr) - path_sql, has_params, has_wildcards = builder.json_path(path + [ key[1] ]) + path_sql, has_params, has_wildcards = builder.build_json_path(path + [key[1]]) return 'JSON_EXISTS(', expr_sql, ', ', path_sql, ')' class OraBoolConverter(dbapiprovider.BoolConverter): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 8054d8e1f..90c5b5c20 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -87,12 +87,12 @@ def eval_json_path(builder, values): else: assert False, value return '{%s}' % ','.join(result) def JSON_QUERY(builder, expr, path): - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) return '(', builder(expr), " #> ", path_sql, ')' json_value_type_mapping = {bool: 'boolean', int: 'integer', float: 'real'} def JSON_VALUE(builder, expr, path, type): if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) sql = '(', builder(expr), " #>> ", path_sql, ')' type_name = builder.json_value_type_mapping.get(type, 'text') return sql if type_name == 'text' else (sql, '::', type_name) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 4bd2cc4d9..ef96a9400 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -135,12 +135,12 @@ def RANDOM(builder): def JSON_QUERY(builder, expr, path): fname = 'json_extract' if builder.json1_available else 'py_json_extract' - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) 
return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', path_sql, '))' # json_value_type_mapping = {unicode: 'text', bool: 'boolean', int: 'integer', float: 'real', Json: None} def JSON_VALUE(builder, expr, path, type): func_name = 'json_extract' if builder.json1_available else 'py_json_extract' - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) return func_name, '(', builder(expr), ', ', path_sql, ')' def JSON_NONZERO(builder, expr): return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' @@ -148,7 +148,7 @@ def JSON_ARRAY_LENGTH(builder, value): func_name = 'json_array_length' if builder.json1_available else 'py_json_array_length' return func_name, '(', builder(value), ')' def JSON_CONTAINS(builder, expr, path, key): - path_sql, has_params, has_wildcards = builder.json_path(path) + path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'py_json_contains(', builder(expr), ', ', path_sql, ', ', builder(key), ')' class SQLiteIntConverter(dbapiprovider.IntConverter): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 57073ef5b..6232559d1 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -517,7 +517,7 @@ def RANDOM(builder): def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] - def json_path(builder, path): + def build_json_path(builder, path): items = [] for element in path: items.extend(builder(element)) empty_slice = slice(None, None, None) From f0c5c1ae867e0dabaa13d0539a8f66a66e036ca3 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Aug 2016 10:36:38 +0300 Subject: [PATCH 055/547] Wildcards are not allowed in MySQL json_contains() --- pony/orm/dbproviders/mysql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py 
index bda8549ac..28bc15069 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -31,7 +31,7 @@ from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions -from pony.orm.sqltranslation import SQLTranslator +from pony.orm.sqltranslation import SQLTranslator, TranslationError from pony.orm.sqlbuilding import SQLBuilder, join from pony.utils import throw from pony.converting import str2timedelta, timedelta2str @@ -112,6 +112,7 @@ def JSON_CONTAINS(builder, expr, path, key): expr_sql = builder(expr) result = [ '(json_contains(', expr_sql, ', ', builder([ 'VALUE', json.dumps([ key[1] ]) ]) ] path_sql, has_params, has_wildcards = builder.build_json_path(path) + if has_wildcards: throw(TranslationError, 'Wildcards are not allowed in json_contains()') path_with_key_sql, _, _ = builder.build_json_path(path + [key[1]]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result From 4faf249c345c7412cef3f3887b4603e51d6ccb7b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 27 Jul 2016 19:55:04 +0300 Subject: [PATCH 056/547] Minor simplification --- pony/orm/sqlbuilding.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 6232559d1..8fe892fe1 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -370,13 +370,13 @@ def make_param(builder, param_class, paramkey, *args): if param is None: param = param_class(builder.paramstyle, paramkey, *args) keys[paramkey] = param - return [ param ] + return param def make_composite_param(builder, paramkey, items, func): return builder.make_param(builder.composite_param_class, paramkey, items, func) def ROW(builder, *items): return '(', join(', ', imap(builder, items)), ')' def VALUE(builder, value): - 
return [builder.value_class(builder.paramstyle, value)] + return builder.value_class(builder.paramstyle, value) def AND(builder, *cond_list): cond_list = [ builder(condition) for condition in cond_list ] return join(' AND ', cond_list) @@ -518,11 +518,10 @@ def RAWSQL(builder, sql): if isinstance(sql, basestring): return sql return [ x if isinstance(x, basestring) else builder(x) for x in sql ] def build_json_path(builder, path): - items = [] - for element in path: items.extend(builder(element)) empty_slice = slice(None, None, None) has_params = False has_wildcards = False + items = [ builder(element) for element in path ] for item in items: if isinstance(item, Param): has_params = True From 70b052f7419df94d15aa22966f9481c38c14b981 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 28 Jul 2016 13:08:15 +0300 Subject: [PATCH 057/547] Minor refactoring --- pony/orm/dbproviders/oracle.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 02b7e58ea..871d8c338 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -10,10 +10,11 @@ import cx_Oracle -from pony.orm import core, sqlbuilding, dbapiprovider, sqltranslation +from pony.orm import core, dbapiprovider, sqltranslation from pony.orm.core import log_orm, log_sql, DatabaseError, TranslationError from pony.orm.dbschema import DBSchema, DBObject, Table, Column from pony.orm.ormtypes import Json +from pony.orm.sqlbuilding import SQLBuilder, Value from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple from pony.utils import throw from pony.converting import timedelta2str @@ -125,10 +126,10 @@ def get_normalized_type_of(translator, value): if value == '': return NoneType return sqltranslation.SQLTranslator.get_normalized_type_of(value) -class OraBuilder(sqlbuilding.SQLBuilder): +class OraBuilder(SQLBuilder): dialect = 'Oracle' def INSERT(builder, 
table_name, columns, values, returning=None): - result = sqlbuilding.SQLBuilder.INSERT(builder, table_name, columns, values) + result = SQLBuilder.INSERT(builder, table_name, columns, values) if returning is not None: result.extend((' RETURNING ', builder.quote_name(returning), ' INTO :new_id')) return result @@ -240,10 +241,10 @@ def JSON_VALUE(builder, expr, path, type): def JSON_NONZERO(builder, expr): return 'COALESCE(', builder(expr), ''', 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_CONTAINS(builder, expr, path, key): - assert key[0] == 'VALUE' and isinstance(key[1], basestring) - expr_sql = builder(expr) - path_sql, has_params, has_wildcards = builder.build_json_path(path + [key[1]]) - return 'JSON_EXISTS(', expr_sql, ', ', path_sql, ')' + key_sql = builder(key) + assert isinstance(key_sql, Value) and isinstance(key_sql.value, basestring) + path_sql, has_params, has_wildcards = builder.build_json_path(path + [ key_sql.value ]) + return 'JSON_EXISTS(', builder(expr), ', ', path_sql, ')' class OraBoolConverter(dbapiprovider.BoolConverter): if not PY2: From 16a285ac39da3b06ee26bd005dcec8b3e3b7654f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 28 Jul 2016 16:05:59 +0300 Subject: [PATCH 058/547] Support external parameters for `key in JSON` operation in SQLite, PostgreSQL, MySQL --- pony/orm/dbproviders/mysql.py | 17 +++++++++++++---- pony/orm/sqltranslation.py | 6 +++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 28bc15069..8245bc4f2 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -32,7 +32,7 @@ from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator, TranslationError -from pony.orm.sqlbuilding import SQLBuilder, join +from pony.orm.sqlbuilding import Value, Param, SQLBuilder, join 
from pony.utils import throw from pony.converting import str2timedelta, timedelta2str @@ -108,14 +108,23 @@ def EQ_JSON(builder, left, right): def NE_JSON(builder, left, right): return '(', builder(left), ' != CAST(', builder(right), ' AS JSON))' def JSON_CONTAINS(builder, expr, path, key): - assert key[0] == 'VALUE' and isinstance(key[1], basestring) + key_sql = builder(key) + if isinstance(key_sql, Value): + wrapped_key = builder.value_class(builder.paramstyle, json.dumps([ key_sql.value ])) + elif isinstance(key_sql, Param): + wrapped_key = builder.make_composite_param( + (key_sql.paramkey,), [key_sql], builder.wrap_param_to_json_array) + else: assert False expr_sql = builder(expr) - result = [ '(json_contains(', expr_sql, ', ', builder([ 'VALUE', json.dumps([ key[1] ]) ]) ] + result = [ '(json_contains(', expr_sql, ', ', wrapped_key ] path_sql, has_params, has_wildcards = builder.build_json_path(path) if has_wildcards: throw(TranslationError, 'Wildcards are not allowed in json_contains()') - path_with_key_sql, _, _ = builder.build_json_path(path + [key[1]]) + path_with_key_sql, _, _ = builder.build_json_path(path + [key]) result += [ ', ', path_sql, ') or json_contains_path(', expr_sql, ", 'one', ", path_with_key_sql, '))' ] return result + @classmethod + def wrap_param_to_json_array(cls, values): + return json.dumps(values) class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 845d80f6a..fe16fb1cd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1579,8 +1579,12 @@ def get_path(monad): def __getitem__(monad, key): return monad.translator.JsonItemMonad(monad, key) def contains(monad, key, not_in=False): - if not isinstance(key, StringConstMonad): raise NotImplementedError translator = monad.translator + if isinstance(key, ParamMonad): + if translator.dialect == 'Oracle': throw(TypeError, + 'For `key in JSON` operation %s supports literal 
key values only, ' + 'parameters are not allowed: {EXPR}' % translator.dialect) + elif not isinstance(key, StringConstMonad): raise NotImplementedError base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] key_sql = key.getsql()[0] From 5f51cb06a120d4f04b561dfdf7a6ee814b645095 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 9 Aug 2016 11:55:47 +0300 Subject: [PATCH 059/547] Add JSON_PARAM to wrap with CAST( AS JSON) in MySQL --- pony/orm/dbproviders/mysql.py | 2 ++ pony/orm/sqlbuilding.py | 2 ++ pony/orm/sqltranslation.py | 4 ++-- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 8245bc4f2..fc2316672 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -125,6 +125,8 @@ def JSON_CONTAINS(builder, expr, path, key): @classmethod def wrap_param_to_json_array(cls, values): return json.dumps(values) + def JSON_PARAM(builder, expr): + return 'CAST(', builder(expr), ' AS JSON)' class MySQLStrConverter(dbapiprovider.StrConverter): def sql_type(converter): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 8fe892fe1..6ba8af430 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -564,3 +564,5 @@ def JSON_CONTAINS(builder, expr, path, key): throw(NotImplementedError) def JSON_ARRAY_LENGTH(builder, value): throw(NotImplementedError) + def JSON_PARAM(builder, expr): + return builder(expr) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index fe16fb1cd..2f1889d79 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1667,9 +1667,9 @@ class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass - class JsonParamMonad(JsonMixin, ParamMonad): - pass + def getsql(monad, subquery=None): + return [ [ 'JSON_PARAM', ParamMonad.getsql(monad)[0] ] ] class ExprMonad(Monad): 
@staticmethod From 4598959af7b8de372e8c01aca8261a9987466eac Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 10 Aug 2016 14:08:12 +0300 Subject: [PATCH 060/547] Fix JSON_CONTAINS in Oracle --- pony/orm/dbproviders/oracle.py | 31 ++++++++++++++++++++++++++----- 1 file changed, 26 insertions(+), 5 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 871d8c338..f9943c60b 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -4,6 +4,7 @@ import os os.environ["NLS_LANG"] = "AMERICAN_AMERICA.UTF8" +import re from datetime import datetime, date, time, timedelta from decimal import Decimal from uuid import UUID @@ -16,7 +17,7 @@ from pony.orm.ormtypes import Json from pony.orm.sqlbuilding import SQLBuilder, Value from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple -from pony.utils import throw +from pony.utils import throw, is_ident from pony.converting import timedelta2str NoneType = type(None) @@ -241,10 +242,30 @@ def JSON_VALUE(builder, expr, path, type): def JSON_NONZERO(builder, expr): return 'COALESCE(', builder(expr), ''', 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_CONTAINS(builder, expr, path, key): - key_sql = builder(key) - assert isinstance(key_sql, Value) and isinstance(key_sql.value, basestring) - path_sql, has_params, has_wildcards = builder.build_json_path(path + [ key_sql.value ]) - return 'JSON_EXISTS(', builder(expr), ', ', path_sql, ')' + assert key[0] == 'VALUE' and isinstance(key[1], basestring) + path_sql, has_params, has_wildcards = builder.build_json_path(path) + path_with_key_sql, _, _ = builder.build_json_path(path + [ key ]) + expr_sql = builder(expr) + result = 'JSON_EXISTS(', expr_sql, ', ', path_with_key_sql, ')' + if json_item_re.match(key[1]): + item = r'"([^"]|\\")*"' + list_start = r'\[\s*(%s\s*,\s*)*' % item + list_end = r'\s*(,\s*%s\s*)*\]' % item + pattern = r'%s"%s"%s' % (list_start, 
key[1], list_end) + if has_wildcards: + sublist = r'\[[^]]*\]' + item_or_sublist = '(%s|%s)' % (item, sublist) + wrapper_list_start = r'^\[\s*(%s\s*,\s*)*' % item_or_sublist + wrapper_list_end = r'\s*(,\s*%s\s*)*\]$' % item_or_sublist + pattern = r'%s%s%s' % (wrapper_list_start, pattern, wrapper_list_end) + result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, " WITH WRAPPER), '%s')" % pattern + else: + pattern = '^%s$' % pattern + result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, "), '%s')" % pattern + return result + +json_item_re = re.compile('[\w\s]*') + class OraBoolConverter(dbapiprovider.BoolConverter): if not PY2: From 32c4d4930ce3693a516f7964304984e87bea36f1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Aug 2016 10:25:56 +0300 Subject: [PATCH 061/547] Add `json_path_wildcard_syntax` & `json_values_are_comparable` translator options --- pony/orm/dbproviders/mysql.py | 1 + pony/orm/dbproviders/oracle.py | 2 ++ pony/orm/sqltranslation.py | 15 ++++++++++++--- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index fc2316672..10ec005c9 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -46,6 +46,7 @@ class MySQLSchema(dbschema.DBSchema): class MySQLTranslator(SQLTranslator): dialect = 'MySQL' + json_path_wildcard_syntax = True class MySQLBuilder(SQLBuilder): dialect = 'MySQL' diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index f9943c60b..51337e4c9 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -119,6 +119,8 @@ def new(translator, value): class OraTranslator(sqltranslation.SQLTranslator): dialect = 'Oracle' rowid_support = True + json_path_wildcard_syntax = True + json_values_are_comparable = False NoneMonad = OraNoneMonad ConstMonad = OraConstMonad diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 2f1889d79..7c68561a7 
100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -69,6 +69,8 @@ def type2str(t): class SQLTranslator(ASTTranslator): dialect = None row_value_syntax = True + json_path_wildcard_syntax = False + json_values_are_comparable = True rowid_support = False def default_post(translator, node): @@ -1680,6 +1682,7 @@ def new(translator, type, sql): elif type is time: cls = translator.TimeExprMonad elif type is timedelta: cls = translator.TimedeltaExprMonad elif type is datetime: cls = translator.DatetimeExprMonad + elif type is Json: cls = translator.JsonExprMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(translator, type, sql) def __new__(cls, *args): @@ -1702,7 +1705,8 @@ class JsonExprMonad(JsonMixin, ExprMonad): pass class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): assert isinstance(parent, JsonMixin), parent - Monad.__init__(monad, parent.translator, Json) + translator = parent.translator + Monad.__init__(monad, translator, Json) monad.parent = parent if isinstance(key, slice): if key != slice(None, None, None): throw(NotImplementedError) @@ -1710,6 +1714,8 @@ def __init__(monad, parent, key): elif isinstance(key, (ParamMonad, StringConstMonad, NumericConstMonad, EllipsisMonad)): monad.key_ast = key.getsql()[0] else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) + if isinstance(key, (slice, EllipsisMonad)) and not translator.json_path_wildcard_syntax: + throw(TranslationError, '%s does not support wildcards in JSON path: {EXPR}' % translator.dialect) def get_path(monad): path = [] while isinstance(monad, JsonItemMonad): @@ -1719,10 +1725,13 @@ def get_path(monad): return monad, path def cast_from_json(monad, type): translator = monad.translator - if issubclass(type, Json): return monad + if issubclass(type, Json): + if not translator.json_values_are_comparable: throw(TranslationError, + '%s does not support comparison of json structures: {EXPR}' % translator.dialect) + 
return monad base_monad, path = monad.get_path() sql = [ 'JSON_VALUE', base_monad.getsql()[0], path, type ] - return translator.ExprMonad.new(translator, type, sql) + return translator.ExprMonad.new(translator, Json if type is NoneType else type, sql) def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] From 44c3b0f4036745296535f95556d83d0db1e8f144 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 17 Aug 2016 15:06:40 +0300 Subject: [PATCH 062/547] Oracle does not provide `length` function for JSON arrays --- pony/orm/dbproviders/oracle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 51337e4c9..1337a50d2 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -265,6 +265,8 @@ def JSON_CONTAINS(builder, expr, path, key): pattern = '^%s$' % pattern result += ' OR REGEXP_LIKE(JSON_QUERY(', expr_sql, ', ', path_sql, "), '%s')" % pattern return result + def JSON_ARRAY_LENGTH(builder, value): + throw(TranslationError, 'Oracle does not provide `length` function for JSON arrays') json_item_re = re.compile('[\w\s]*') From 158a41a3e1927fb1119a89f398b8f7d8a16a6a00 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 17 Aug 2016 15:08:07 +0300 Subject: [PATCH 063/547] Oracle doesn't allow parameters in JSON paths --- pony/orm/dbproviders/oracle.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 1337a50d2..d7b4636a0 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -230,6 +230,10 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def build_json_path(builder, path): + path_sql, has_params, has_wildcards = SQLBuilder.build_json_path(builder, path) + 
if has_params: throw(TranslationError, "Oracle doesn't allow parameters in JSON paths") + return path_sql, has_params, has_wildcards def JSON_QUERY(builder, expr, path): expr_sql = builder(expr) path_sql, has_params, has_wildcards = builder.build_json_path(path) From 04c1655e78661b491d334aeb4dfaafd9b91326a0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 11 Aug 2016 11:32:08 +0300 Subject: [PATCH 064/547] Move code around --- pony/orm/sqltranslation.py | 86 +++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 44 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 7c68561a7..2e02a2209 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1467,6 +1467,48 @@ def call_lstrip(monad, chars=None): def call_rstrip(monad, chars=None): return monad.strip(chars, 'RTRIM') +class JsonMixin(object): + disable_distinct = True # at least in Oracle we cannot use DISTINCT with JSON column + disable_ordering = True # at least in Oracle we cannot use ORDER BY with JSON column + + def mixin_init(monad): + assert monad.type is Json, monad.type + def get_path(monad): + return monad, [] + def __getitem__(monad, key): + return monad.translator.JsonItemMonad(monad, key) + def contains(monad, key, not_in=False): + translator = monad.translator + if isinstance(key, ParamMonad): + if translator.dialect == 'Oracle': throw(TypeError, + 'For `key in JSON` operation %s supports literal key values only, ' + 'parameters are not allowed: {EXPR}' % translator.dialect) + elif not isinstance(key, StringConstMonad): raise NotImplementedError + base_monad, path = monad.get_path() + base_sql = base_monad.getsql()[0] + key_sql = key.getsql()[0] + sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] + if not_in: sql = [ 'NOT', sql ] + return translator.BoolExprMonad(translator, sql) + def __or__(monad, other): + translator = monad.translator + if not isinstance(other, translator.JsonMixin): + raise TypeError('Should be 
JSON: %s' % ast2src(other.node)) + left_sql = monad.getsql()[0] + right_sql = other.getsql()[0] + sql = [ 'JSON_CONCAT', left_sql, right_sql ] + return translator.JsonExprMonad(translator, Json, sql) + def len(monad): + translator = monad.translator + sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] + return translator.NumericExprMonad(translator, int, sql) + def cast_from_json(monad, type): + if type in (Json, NoneType): return monad + throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') + def nonzero(monad): + translator = monad.translator + return translator.BoolExprMonad(translator, [ 'JSON_NONZERO', monad.getsql()[0] ]) + class ObjectMixin(MonadMixin): def mixin_init(monad): assert isinstance(monad.type, EntityMeta) @@ -1568,50 +1610,6 @@ class TimedeltaAttrMonad(TimedeltaMixin, AttrMonad): pass class DatetimeAttrMonad(DatetimeMixin, AttrMonad): pass class BufferAttrMonad(BufferMixin, AttrMonad): pass class UuidAttrMonad(UuidMixin, AttrMonad): pass - - -class JsonMixin(object): - disable_distinct = True # at least in Oracle we cannot use DISTINCT with JSON column - disable_ordering = True # at least in Oracle we cannot use ORDER BY with JSON column - - def mixin_init(monad): - assert monad.type is Json, monad.type - def get_path(monad): - return monad, [] - def __getitem__(monad, key): - return monad.translator.JsonItemMonad(monad, key) - def contains(monad, key, not_in=False): - translator = monad.translator - if isinstance(key, ParamMonad): - if translator.dialect == 'Oracle': throw(TypeError, - 'For `key in JSON` operation %s supports literal key values only, ' - 'parameters are not allowed: {EXPR}' % translator.dialect) - elif not isinstance(key, StringConstMonad): raise NotImplementedError - base_monad, path = monad.get_path() - base_sql = base_monad.getsql()[0] - key_sql = key.getsql()[0] - sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] - if not_in: sql = [ 'NOT', sql ] - return 
translator.BoolExprMonad(translator, sql) - def __or__(monad, other): - translator = monad.translator - if not isinstance(other, translator.JsonMixin): - raise TypeError('Should be JSON: %s' % ast2src(other.node)) - left_sql = monad.getsql()[0] - right_sql = other.getsql()[0] - sql = [ 'JSON_CONCAT', left_sql, right_sql ] - return translator.JsonExprMonad(translator, Json, sql) - def len(monad): - translator = monad.translator - sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] - return translator.NumericExprMonad(translator, int, sql) - def cast_from_json(monad, type): - if type in (Json, NoneType): return monad - throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') - def nonzero(monad): - translator = monad.translator - return translator.BoolExprMonad(translator, [ 'JSON_NONZERO', monad.getsql()[0] ]) - class JsonAttrMonad(JsonMixin, AttrMonad): pass class ParamMonad(Monad): From 8321ac199b954f525772be94f49b0ce250ed1662 Mon Sep 17 00:00:00 2001 From: Vitalii Date: Fri, 19 Aug 2016 14:18:22 +0300 Subject: [PATCH 065/547] json1: skip tests if extension is not available --- pony/orm/tests/fixtures.py | 69 ++++++++++++++------------------ pony/orm/tests/test_json/test.py | 1 - 2 files changed, 29 insertions(+), 41 deletions(-) diff --git a/pony/orm/tests/fixtures.py b/pony/orm/tests/fixtures.py index 47c8921b5..e337e9fac 100644 --- a/pony/orm/tests/fixtures.py +++ b/pony/orm/tests/fixtures.py @@ -3,7 +3,8 @@ import logging from pony.py23compat import PY2 -from ponytest import with_cli_args, pony_fixtures, provider_validators, provider +from ponytest import with_cli_args, pony_fixtures, provider_validators, provider, Fixture, \ + ValidationError from functools import wraps, partial import click @@ -12,9 +13,9 @@ from pony.utils import cached_property, class_property if not PY2: - from contextlib import contextmanager, ContextDecorator, ExitStack + from contextlib import contextmanager, ContextDecorator else: - from 
contextlib2 import contextmanager, ContextDecorator, ExitStack + from contextlib2 import contextmanager, ContextDecorator import unittest @@ -29,7 +30,6 @@ import threading - class DBContext(ContextDecorator): fixture = 'db' @@ -164,7 +164,7 @@ def init_db(self): self.drop_db(c) except pyodbc.DatabaseError as exc: print('Failed to drop db: %s' % exc) - c.execute('create database %s' % self.db_name) + c.execute('''CREATE DATABASE %s DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci''' % self.db_name ) c.execute('use %s' % self.db_name) def drop_db(self, cursor): @@ -173,8 +173,8 @@ def drop_db(self, cursor): @provider() -class SqliteContext(DBContext): - provider_key = 'sqlite' +class SqliteJson1(DBContext): + provider_key = 'sqlite_json1' enabled = True def init_db(self): @@ -183,7 +183,11 @@ def init_db(self): except OSError as exc: print('Failed to drop db: %s' % exc) - fixture_name = 'sqlite, with json1' + def __enter__(self): + result = super(SqliteJson1, self).__enter__() + if not self.db.provider.json1_available: + raise unittest.SkipTest + return result # TODO if json1 is not installed, do not run the tests @@ -342,10 +346,9 @@ def logging_context(test): logging.getLogger().setLevel(level) sql_debug(debug) -# @provider('log_all', scope='class', weight=-100, enabled=False) -# def log_all(Test): -# return logging_context(Test) - +@provider(fixture='log_all', weight=-100, enabled=False) +def log_all(Test): + return logging_context(Test) # @with_cli_args @@ -358,17 +361,15 @@ def logging_context(test): # yield log_all - - -@provider() -class DBSessionProvider(object): - - fixture= 'db_session' - - weight = 30 - - def __new__(cls, test): - return db_session +# @provider(enabled=False) +# class DBSessionProvider(object): +# +# fixture= 'db_session' +# +# weight = 30 +# +# def __new__(cls, test): +# return db_session @provider(fixture='rollback', weight=40) @@ -386,11 +387,8 @@ class SeparateProcess(object): # TODO read failures from sep process better 
fixture = 'separate_process' - enabled = False - scope = 'class' - def __init__(self, Test): self.Test = Test @@ -448,26 +446,25 @@ def __exit__(self, *exc_info): @provider() -class NoJson1(SqliteContext): +class SqliteNoJson1(SqliteJson1): provider_key = 'sqlite_no_json1' - fixture = 'db' def __init__(self, cls): self.Test = cls cls.no_json1 = True - return super(NoJson1, self).__init__(cls) + return super(SqliteNoJson1, self).__init__(cls) fixture_name = 'sqlite, no json1' def __enter__(self): - resource = super(NoJson1, self).__enter__() + resource = super(SqliteNoJson1, self).__enter__() self.json1_available = self.Test.db.provider.json1_available self.Test.db.provider.json1_available = False return resource def __exit__(self, *exc_info): self.Test.db.provider.json1_available = self.json1_available - return super(NoJson1, self).__exit__() + return super(SqliteNoJson1, self).__exit__() @@ -484,7 +481,6 @@ def __init__(self, Test, timeout): self.Test = Test self.timeout = timeout if timeout else Test.TIMEOUT - scope = 'class' enabled = False class Exception(Exception): @@ -529,19 +525,12 @@ def validate_chain(cls, fixtures, klass, timeout): pony_fixtures['test'].extend([ 'log', 'clear_tables', - 'db_session', ]) pony_fixtures['class'].extend([ 'separate_process', 'timeout', 'db', + 'log_all', 'generate_mapping', ]) - -# def db_is_required(providers, config): -# return providers - -# provider_validators.update({ -# 'db': db_is_required, -# }) \ No newline at end of file diff --git a/pony/orm/tests/test_json/test.py b/pony/orm/tests/test_json/test.py index f70b0e3b6..52bc90f7c 100644 --- a/pony/orm/tests/test_json/test.py +++ b/pony/orm/tests/test_json/test.py @@ -16,7 +16,6 @@ class TestJson(TestCase): - in_db_session = False @classmethod def make_entities(cls): From 19af6be490feed4d4df4185df97af8ec80c9cda3 Mon Sep 17 00:00:00 2001 From: Vitalii Date: Fri, 19 Aug 2016 19:33:22 +0300 Subject: [PATCH 066/547] test fixtures: use no_json1 as default for sqlite --- 
pony/orm/tests/fixtures.py | 69 ++++++++++++++++++-------------------- 1 file changed, 33 insertions(+), 36 deletions(-) diff --git a/pony/orm/tests/fixtures.py b/pony/orm/tests/fixtures.py index e337e9fac..f8fef0a32 100644 --- a/pony/orm/tests/fixtures.py +++ b/pony/orm/tests/fixtures.py @@ -172,10 +172,7 @@ def drop_db(self, cursor): cursor.execute('drop database %s' % self.db_name) -@provider() -class SqliteJson1(DBContext): - provider_key = 'sqlite_json1' - enabled = True +class SqliteMixin(DBContext): def init_db(self): try: @@ -183,15 +180,6 @@ def init_db(self): except OSError as exc: print('Failed to drop db: %s' % exc) - def __enter__(self): - result = super(SqliteJson1, self).__enter__() - if not self.db.provider.json1_available: - raise unittest.SkipTest - return result - - - # TODO if json1 is not installed, do not run the tests - @cached_property def db_path(self): p = os.path.dirname(__file__) @@ -203,6 +191,38 @@ def db(self): return Database('sqlite', self.db_path, create_db=True) +@provider() +class SqliteNoJson1(SqliteMixin): + provider_key = 'sqlite_no_json1' + enabled = True + + def __init__(self, cls): + self.Test = cls + cls.no_json1 = True + return super(SqliteNoJson1, self).__init__(cls) + + def __enter__(self): + resource = super(SqliteNoJson1, self).__enter__() + self.json1_available = self.Test.db.provider.json1_available + self.Test.db.provider.json1_available = False + return resource + + def __exit__(self, *exc_info): + self.Test.db.provider.json1_available = self.json1_available + return super(SqliteNoJson1, self).__exit__() + + +@provider() +class SqliteJson1(SqliteMixin): + provider_key = 'sqlite_json1' + + def __enter__(self): + result = super(SqliteJson1, self).__enter__() + if not self.db.provider.json1_available: + raise unittest.SkipTest + return result + + @provider() class PostgresContext(DBContext): provider_key = 'postgresql' @@ -445,29 +465,6 @@ def __exit__(self, *exc_info): delete(i for i in entity) -@provider() -class 
SqliteNoJson1(SqliteJson1): - provider_key = 'sqlite_no_json1' - - def __init__(self, cls): - self.Test = cls - cls.no_json1 = True - return super(SqliteNoJson1, self).__init__(cls) - - fixture_name = 'sqlite, no json1' - - def __enter__(self): - resource = super(SqliteNoJson1, self).__enter__() - self.json1_available = self.Test.db.provider.json1_available - self.Test.db.provider.json1_available = False - return resource - - def __exit__(self, *exc_info): - self.Test.db.provider.json1_available = self.json1_available - return super(SqliteNoJson1, self).__exit__() - - - import signal @provider() From 914867c4c98429737e1ce61c35273dfc60cda043 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 19 Aug 2016 21:13:54 +0300 Subject: [PATCH 067/547] Convert test_json from package to module --- .../tests/{test_json/test.py => test_json.py} | 0 pony/orm/tests/test_json/__init__.py | 0 pony/orm/tests/test_json/_postgres.py | 127 ------------------ 3 files changed, 127 deletions(-) rename pony/orm/tests/{test_json/test.py => test_json.py} (100%) delete mode 100644 pony/orm/tests/test_json/__init__.py delete mode 100644 pony/orm/tests/test_json/_postgres.py diff --git a/pony/orm/tests/test_json/test.py b/pony/orm/tests/test_json.py similarity index 100% rename from pony/orm/tests/test_json/test.py rename to pony/orm/tests/test_json.py diff --git a/pony/orm/tests/test_json/__init__.py b/pony/orm/tests/test_json/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/pony/orm/tests/test_json/_postgres.py b/pony/orm/tests/test_json/_postgres.py deleted file mode 100644 index 7e94e9016..000000000 --- a/pony/orm/tests/test_json/_postgres.py +++ /dev/null @@ -1,127 +0,0 @@ -''' -Postgres-specific tests -''' - -import unittest - -from pony.orm import * -from pony.orm.ormtypes import Json -from pony.orm.tests.testutils import raises_exception - -from . 
import SetupTest - - -class JsonConcatTest(SetupTest, unittest.TestCase): - - @classmethod - def bindDb(cls): - cls.db = Database('postgres', user='postgres', password='postgres', - database='testjson', host='localhost') - - @db_session - def setUp(self): - info = ['description', 4, {'size': '100x50'}] - self.E(article='A-347', info=info, extra_info={'overpriced': True}) - - @db_session - def test_field(self): - result = select(m.info[2] | m.extra_info for m in self.M)[:] - self.assertDictEqual(result[0], {u'overpriced': True, u'size': u'100x50'}) - - @db_session - def test_param(self): - x = 17 - result = select(m.info[2] | {"weight": x} for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) - - @db_session - def test_complex_param(self): - x = {"weight": {'net': 17}} - result = select(m.info[2] | x for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertDictEqual(result[0], {'weight': {'net': 17}, 'size': '100x50'}) - - @db_session - def test_complex_param_2(self): - x = {'net': 17} - result = select(m.info[2] | {"weight": x} for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertDictEqual(result[0], {'weight': {'net': 17}, 'size': '100x50'}) - - @db_session - def test_str_const(self): - result = select(m.info[2] | {"weight": 17} for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) - - @db_session - def test_str_param(self): - extra = {"weight": 17} - result = select(m.info[2] | extra for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) - - @raises_exception(Exception) - @db_session - def test_no_json_wrapper(self): - result = select(m.info[2] | '{"weight": 17}' for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertDictEqual(result[0], {'weight': 17, 'size': '100x50'}) - - -class JsonContainsTest(SetupTest, 
unittest.TestCase): - - @classmethod - def bindDb(cls): - cls.db = Database('postgres', user='postgres', password='postgres', - database='testjson', host='localhost') - - @db_session - def setUp(self): - info = ['description', 4, {'size': '100x50'}] - self.M(article='A-347', info=info, extra_info={'overpriced': True}) - - @db_session - def test_key_in(self): - result = select('size' in m.info[2] for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], True) - - @db_session - def test_contains(self): - result = select({"size": "100x50"} in m.info[2] for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], True) - - @db_session - def test_contains_param(self): - for size in ['100x50', '200x100']: - result = select({"size": "%s" % size} in m.info[2] for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], size == '100x50') - - @db_session - def test_list(self): - result = select(Json(["description"]) in m.info for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], True) - - @db_session - def test_contains_field(self): - result = select({"size": "100x50"} in m.info[2] for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], True) - - @db_session - def test_inverse_order(self): - result = select(m.info[2] in {"size": "100x50", "weight": 1} for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], True) - - @db_session - def test_with_concat(self): - result = select((m.info[2] | {'weight': 1}) in {"size": "100x50", "weight": 1} - for m in self.M)[:] - self.assertEqual(len(result), 1) - self.assertEqual(result[0], True) From be9345149eeed0c02a1cc5dfcd32dcca54c28e04 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Mon, 22 Aug 2016 13:35:26 +0300 Subject: [PATCH 068/547] update setup.py: add pony.utils subpackage --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git 
a/setup.py b/setup.py index 95020b54a..3ab1a3104 100644 --- a/setup.py +++ b/setup.py @@ -84,7 +84,8 @@ "pony.orm.integration", "pony.orm.tests", "pony.thirdparty", - "pony.thirdparty.compiler" + "pony.thirdparty.compiler", + "pony.utils" ] download_url = "http://pypi.python.org/pypi/pony/" From 0a4ac5addb57af965192846b44b636990d12420f Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Mon, 22 Aug 2016 12:24:15 +0300 Subject: [PATCH 069/547] Pony ORM Release 0.6.6 changelog --- CHANGELOG.md | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index bde74fca1..b1b434124 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,24 @@ +# Pony ORM Release 0.6.6 (2016-08-22) + +## New features + +* Added native JSON data type support in all supported databases: https://docs.ponyorm.com/json.html + +## Backward incompatible changes + +* Dropped Python 2.6 support + +## Improvements + +* #179 Added the compatibility with PYPY using psycopg2cffi +* Added an experimental @db_session `strict` parameter: https://docs.ponyorm.com/transactions.html#strict + +## Bugfixes + +* #182 - LEFT JOIN doesn't work as expected for inherited entities when foreign key is None +* Some small bugs were fixed + + # Pony ORM Release 0.6.5 (2016-04-04) ## Improvements From 8e03a477a7ce9747de28efd70950032ed83569c6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 22 Aug 2016 13:19:02 +0300 Subject: [PATCH 070/547] Update Pony version: 0.6.6-dev -> 0.6.6 --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index d838e528f..acfb7da3d 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.6.6-dev' +__version__ = '0.6.6' uid = str(random.randint(1, 1000000)) From 7cbac28dd239a72f85679fd7335c52e03b9c39c0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 22 Aug 2016 
16:14:09 +0300 Subject: [PATCH 071/547] Update Pony version: 0.6.6 -> 0.6.7-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index acfb7da3d..384388bdb 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.6.6' +__version__ = '0.6.7-dev' uid = str(random.randint(1, 1000000)) From 01cca570a0e7ce6b59a8f4b00c50aa231dbd2953 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 29 Aug 2016 19:46:37 +0300 Subject: [PATCH 072/547] Fix #190: Timedelta not supported when using pymysql --- pony/orm/dbproviders/mysql.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 10ec005c9..c9eaa9a97 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -22,10 +22,11 @@ import pymysql as mysql_module except ImportError: raise ImportError('No module named MySQLdb or pymysql found') + from pymysql.converters import escape_str import pymysql.converters as mysql_converters from pymysql.constants import FIELD_TYPE, FLAG, CLIENT - if PY2: mysql_converters.encoders[buffer] = lambda val: mysql_converters.escape_str(str(val)) - mysql_converters.encoders[timedelta] = lambda val: mysql_converters.escape_str(timedelta2str(val)) + if PY2: mysql_converters.encoders[buffer] = lambda val, encoders=None: escape_str(str(val), encoders) + mysql_converters.encoders[timedelta] = lambda val, encoders=None: escape_str(timedelta2str(val), encoders) mysql_module_name = 'pymysql' from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation From be8f67f7d8b9142efc2fe35bc0685077013c4389 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 30 Aug 2016 11:44:52 +0300 Subject: [PATCH 073/547] Refactoring: import pymysql.converters.escape_str as string_literal to be similar with MySQLdb --- 
pony/orm/dbproviders/mysql.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index c9eaa9a97..1664f2f8e 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -22,11 +22,11 @@ import pymysql as mysql_module except ImportError: raise ImportError('No module named MySQLdb or pymysql found') - from pymysql.converters import escape_str + from pymysql.converters import escape_str as string_literal import pymysql.converters as mysql_converters from pymysql.constants import FIELD_TYPE, FLAG, CLIENT - if PY2: mysql_converters.encoders[buffer] = lambda val, encoders=None: escape_str(str(val), encoders) - mysql_converters.encoders[timedelta] = lambda val, encoders=None: escape_str(timedelta2str(val), encoders) + if PY2: mysql_converters.encoders[buffer] = lambda val, encoders=None: string_literal(str(val), encoders) + mysql_converters.encoders[timedelta] = lambda val, encoders=None: string_literal(timedelta2str(val), encoders) mysql_module_name = 'pymysql' from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation From f28e5638f8f0253d216e3401e7b5c0136f77ef97 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 30 Aug 2016 11:58:44 +0300 Subject: [PATCH 074/547] Refactoring: specify converters during connection setup in the same way for MySQLdb and pymysql --- pony/orm/dbproviders/mysql.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 1664f2f8e..d178744b0 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -25,8 +25,6 @@ from pymysql.converters import escape_str as string_literal import pymysql.converters as mysql_converters from pymysql.constants import FIELD_TYPE, FLAG, CLIENT - if PY2: mysql_converters.encoders[buffer] = lambda val, encoders=None: string_literal(str(val), encoders) - 
mysql_converters.encoders[timedelta] = lambda val, encoders=None: string_literal(timedelta2str(val), encoders) mysql_module_name = 'pymysql' from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation @@ -229,7 +227,10 @@ def get_pool(provider, *args, **kwargs): conv = mysql_converters.conversions.copy() if mysql_module_name == 'MySQLdb': conv[FIELD_TYPE.BLOB] = [(FLAG.BINARY, buffer)] - conv[timedelta] = lambda td, c: string_literal(timedelta2str(td), c) + else: + if PY2: + conv[buffer] = lambda val, encoders=None: string_literal(str(val), encoders) + conv[timedelta] = lambda val, encoders=None: string_literal(timedelta2str(val), encoders) conv[FIELD_TYPE.TIMESTAMP] = str2datetime conv[FIELD_TYPE.DATETIME] = str2datetime conv[FIELD_TYPE.TIME] = str2timedelta From ca7baab42cbf764d63e0672e31000df84671a6e2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 30 Aug 2016 12:03:50 +0300 Subject: [PATCH 075/547] Refactoring: replace lambdas with named functions --- pony/orm/dbproviders/mysql.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index d178744b0..bb8346bf0 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -229,8 +229,15 @@ def get_pool(provider, *args, **kwargs): conv[FIELD_TYPE.BLOB] = [(FLAG.BINARY, buffer)] else: if PY2: - conv[buffer] = lambda val, encoders=None: string_literal(str(val), encoders) - conv[timedelta] = lambda val, encoders=None: string_literal(timedelta2str(val), encoders) + def encode_buffer(val, encoders=None): + return string_literal(str(val), encoders) + + conv[buffer] = encode_buffer + + def encode_timedelta(val, encoders=None): + return string_literal(timedelta2str(val), encoders) + + conv[timedelta] = encode_timedelta conv[FIELD_TYPE.TIMESTAMP] = str2datetime conv[FIELD_TYPE.DATETIME] = str2datetime conv[FIELD_TYPE.TIME] = str2timedelta From 6e15313c6de62f6239ea09958ca9cfdf4423ecd1 Mon 
Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 30 Aug 2016 14:51:57 +0300 Subject: [PATCH 076/547] Improved exception message --- pony/orm/sqltranslation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 2e02a2209..10b6ba51f 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1194,7 +1194,9 @@ def __init__(monad, translator, items): monad.items = items def contains(monad, x, not_in=False): translator = monad.translator - for item in monad.items: check_comparable(item, x) + if isinstance(x.type, SetType): throw(TypeError, + "Type of `%s` is '%s'. Expression `{EXPR}` is not supported" % (ast2src(x.node), type2str(x.type))) + for item in monad.items: check_comparable(x, item) left_sql = x.getsql() if len(left_sql) == 1: if not_in: sql = [ 'NOT_IN', left_sql[0], [ item.getsql()[0] for item in monad.items ] ] From ac21f410b0ac99529eb0060a9f002f198c2e894f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 2 Mar 2016 20:03:57 +0300 Subject: [PATCH 077/547] Add descriptions to db_session-related tests --- pony/orm/tests/test_db_session.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 741cf5277..3de210baf 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -32,6 +32,7 @@ def test_db_session_3(self): self.assertTrue(db_session is db_session()) def test_db_session_4(self): + # Nested db_sessions are ignored with db_session: with db_session: self.X(a=3, b=3) @@ -39,6 +40,7 @@ def test_db_session_4(self): self.assertEqual(count(x for x in self.X), 3) def test_db_session_decorator_1(self): + # Should commit changes on exit from db_session @db_session def test(): self.X(a=3, b=3) @@ -47,6 +49,7 @@ def test(): self.assertEqual(count(x for x in self.X), 3) def test_db_session_decorator_2(self): + # Should rollback 
changes if an exception is occurred @db_session def test(): self.X(a=3, b=3) @@ -60,6 +63,7 @@ def test(): self.fail() def test_db_session_decorator_3(self): + # Should rollback changes if the exception is not in the list of allowed exceptions @db_session(allowed_exceptions=[TypeError]) def test(): self.X(a=3, b=3) @@ -73,6 +77,7 @@ def test(): self.fail() def test_db_session_decorator_4(self): + # Should commit changes if the exception is in the list of allowed exceptions @db_session(allowed_exceptions=[ZeroDivisionError]) def test(): self.X(a=3, b=3) @@ -98,6 +103,7 @@ def test(): pass def test_db_session_decorator_7(self): + # Should not to do retry until retry count is specified counter = count() @db_session(retry_exceptions=[ZeroDivisionError]) def test(): @@ -114,6 +120,7 @@ def test(): self.fail() def test_db_session_decorator_8(self): + # Should rollback & retry 1 time if retry=1 counter = count() @db_session(retry=1, retry_exceptions=[ZeroDivisionError]) def test(): @@ -130,6 +137,7 @@ def test(): self.fail() def test_db_session_decorator_9(self): + # Should rollback & retry N time if retry=N counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) def test(): @@ -146,6 +154,7 @@ def test(): self.fail() def test_db_session_decorator_10(self): + # Should not retry if the exception not in the list of retry_exceptions counter = count() @db_session(retry=3, retry_exceptions=[TypeError]) def test(): @@ -162,6 +171,7 @@ def test(): self.fail() def test_db_session_decorator_11(self): + # Should commit after successful retrying counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) def test(): @@ -186,6 +196,7 @@ def test(): pass def test_db_session_decorator_13(self): + # allowed_exceptions may be callable, should commit if nonzero @db_session(allowed_exceptions=lambda e: isinstance(e, ZeroDivisionError)) def test(): self.X(a=3, b=3) @@ -199,6 +210,7 @@ def test(): self.fail() def test_db_session_decorator_14(self): + # 
allowed_exceptions may be callable, should rollback if not nonzero @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) def test(): self.X(a=3, b=3) @@ -212,6 +224,7 @@ def test(): self.fail() def test_db_session_decorator_15(self): + # retry_exceptions may be callable, should retry if nonzero counter = count() @db_session(retry=3, retry_exceptions=lambda e: isinstance(e, ZeroDivisionError)) def test(): @@ -240,6 +253,7 @@ def test_db_session_manager_2(self): self.X(a=3, b=3) def test_db_session_manager_3(self): + # Should rollback if the exception is not in the list of allowed_exceptions try: with db_session(allowed_exceptions=[TypeError]): self.X(a=3, b=3) @@ -251,6 +265,7 @@ def test_db_session_manager_3(self): self.fail() def test_db_session_manager_4(self): + # Should commit if the exception is in the list of allowed_exceptions try: with db_session(allowed_exceptions=[ZeroDivisionError]): self.X(a=3, b=3) From 524e7adac74cc755c4da43cc9db57aa72abd11c4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 6 Sep 2016 12:37:54 +0300 Subject: [PATCH 078/547] Whitespace normalization --- pony/orm/tests/test_db_session.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 3de210baf..13147cc8b 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -208,7 +208,7 @@ def test(): self.assertEqual(count(x for x in self.X), 3) else: self.fail() - + def test_db_session_decorator_14(self): # allowed_exceptions may be callable, should rollback if not nonzero @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) @@ -322,65 +322,78 @@ class Student(db.Entity): class TestDBSessionScope(unittest.TestCase): def setUp(self): rollback() + def tearDown(self): rollback() + def test1(self): with db_session: s1 = Student[1] name = s1.name + @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute 
Student[1].picture: the database session is over') def test2(self): with db_session: s1 = Student[1] picture = s1.picture + @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Group[1].major: the database session is over') def test3(self): with db_session: s1 = Student[1] group_id = s1.group.id major = s1.group.major + @raises_exception(DatabaseSessionIsOver, 'Cannot assign new value to attribute Student[1].name: the database session is over') def test4(self): with db_session: s1 = Student[1] s1.name = 'New name' + def test5(self): with db_session: g1 = Group[1] self.assertEqual(str(g1.students), 'StudentSet([...])') + @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Group[1].students: the database session is over') def test6(self): with db_session: g1 = Group[1] l = len(g1.students) + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') def test7(self): with db_session: s1 = Student[1] g1 = Group[1] g1.students.remove(s1) + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') def test8(self): with db_session: g2_students = Group[2].students g1 = Group[1] g1.students = g2_students + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') def test9(self): with db_session: s3 = Student[3] g1 = Group[1] g1.students.add(s3) + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') def test10(self): with db_session: g1 = Group[1] g1.students.clear() + @raises_exception(DatabaseSessionIsOver, 'Cannot delete object Student[1]: the database session is over') def test11(self): with db_session: s1 = Student[1] s1.delete() + @raises_exception(DatabaseSessionIsOver, 'Cannot change object Student[1]: the database session is over') def test12(self): with db_session: From 
37306aed7237258b36d77f0c119007d1d2c12b71 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 1 Mar 2016 17:12:40 +0300 Subject: [PATCH 079/547] Fixes #159: exceptions happened during flush() should not be wrapped with CommitException --- pony/orm/core.py | 11 ++++++++++- pony/orm/tests/test_db_session.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 5cc9f1b2a..16ac32c16 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -280,10 +280,19 @@ def transact_reraise(exc_class, exceptions): def commit(): caches = _get_caches() if not caches: return + + try: + for cache in caches: + cache.flush() + except: + rollback() + raise + primary_cache = caches[0] other_caches = caches[1:] exceptions = [] - try: primary_cache.commit() + try: + primary_cache.commit() except: exceptions.append(sys.exc_info()) for cache in other_caches: diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 13147cc8b..3d56e060a 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -296,6 +296,34 @@ def test(): pass test() + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_1(self): + def before_insert(self): + 1/0 + self.X.before_insert = before_insert + with db_session: + self.X(a=3, b=3) + # Should raise ZeroDivisionError and not CommitException + + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_2(self): + def before_insert(self): + 1 / 0 + self.X.before_insert = before_insert + with db_session: + self.X(a=3, b=3) + commit() + # Should raise ZeroDivisionError and not CommitException + + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_3(self): + def before_insert(self): + 1 / 0 + self.X.before_insert = before_insert + with db_session: + self.X(a=3, b=3) + db.commit() + # Should raise ZeroDivisionError and not CommitException db = Database('sqlite', ':memory:') From 
5df7dd4d68301dc777985b54e2266afedbae198a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 1 Mar 2016 13:13:09 +0300 Subject: [PATCH 080/547] Micro refactoring --- pony/orm/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 16ac32c16..2078f255a 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1527,13 +1527,11 @@ def prepare_connection_for_query_execution(cache): return connection def commit(cache): assert cache.is_alive - database = cache.database - provider = database.provider try: if cache.modified: cache.flush() if cache.in_transaction: assert cache.connection is not None - provider.commit(cache.connection, cache) + cache.database.provider.commit(cache.connection, cache) cache.for_update.clear() cache.immediate = True except: From a983af04aaa6a59864983b6335d0954ed54103c7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 6 Sep 2016 11:34:17 +0300 Subject: [PATCH 081/547] db.commit() exception should be wrapped with CommitException similar to global commit() function. 
The same for db.rollback() --- pony/orm/core.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2078f255a..2bb959ab6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -607,11 +607,14 @@ def flush(database): @cut_traceback def commit(database): cache = local.db2cache.get(database) - if cache is not None: cache.commit() + if cache is not None: + cache.flush_and_commit() @cut_traceback def rollback(database): cache = local.db2cache.get(database) - if cache is not None: cache.rollback() + if cache is not None: + try: cache.rollback() + except: transact_reraise(RollbackException, [sys.exc_info()]) @cut_traceback def execute(database, sql, globals=None, locals=None): return database._exec_raw_sql(sql, globals, locals, frame_depth=3, start_transaction=True) @@ -1512,7 +1515,7 @@ def prepare_connection_for_query_execution(cache): # in the interactive mode, outside of the db_session if cache.in_transaction or cache.modified: local.db_session = None - try: cache.commit() + try: cache.flush_and_commit() finally: local.db_session = db_session cache.db_session = db_session cache.immediate = cache.immediate or db_session.immediate @@ -1525,6 +1528,13 @@ def prepare_connection_for_query_execution(cache): except Exception as e: connection = cache.reconnect(e) if not cache.noflush_counter and cache.modified: cache.flush() return connection + def flush_and_commit(cache): + try: cache.flush() + except: + cache.rollback() + raise + try: cache.commit() + except: transact_reraise(CommitException, [sys.exc_info()]) def commit(cache): assert cache.is_alive try: From 06a7fa21cf62f74a45fb0abd2e4ad497c88402be Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 6 Sep 2016 14:02:56 +0300 Subject: [PATCH 082/547] Remove trailing whitespaces --- pony/orm/examples/estore.py | 4 ++-- pony/orm/tests/test_core_multiset.py | 2 +- pony/orm/tests/test_declarative_attr_set_monad.py | 4 ++-- 
pony/orm/tests/test_declarative_sqltranslator.py | 2 +- pony/orm/tests/test_declarative_strings.py | 4 ++-- pony/orm/tests/test_diagram_attribute.py | 4 ++-- pony/orm/tests/test_indexes.py | 2 +- pony/orm/tests/test_inheritance.py | 4 ++-- pony/orm/tests/test_objects_to_save_cleanup.py | 7 +++---- pony/orm/tests/test_relations_m2m.py | 8 ++++---- pony/orm/tests/test_relations_one2many.py | 6 +++--- pony/orm/tests/test_relations_one2one3.py | 2 +- 12 files changed, 24 insertions(+), 25 deletions(-) diff --git a/pony/orm/examples/estore.py b/pony/orm/examples/estore.py index e62d3374d..a7de96a77 100644 --- a/pony/orm/examples/estore.py +++ b/pony/orm/examples/estore.py @@ -229,7 +229,7 @@ def test_queries(): print('Three most valuable customers') print() result = select(c for c in Customer).order_by(lambda c: desc(sum(c.orders.total_price)))[:3] - + print(result) print() @@ -276,7 +276,7 @@ def test_queries(): for customer in Customer for product in customer.orders.items.product for category in product.categories - if count(product) > 1)[:] + if count(product) > 1)[:] print(result) print() diff --git a/pony/orm/tests/test_core_multiset.py b/pony/orm/tests/test_core_multiset.py index 61b557c54..d1374cfa4 100644 --- a/pony/orm/tests/test_core_multiset.py +++ b/pony/orm/tests/test_core_multiset.py @@ -119,7 +119,7 @@ def test_multiset_ne(self): d = Department[1] multiset = d.groups.students.courses self.assertFalse(multiset != multiset) - + @db_session def test_multiset_contains(self): d = Department[1] diff --git a/pony/orm/tests/test_declarative_attr_set_monad.py b/pony/orm/tests/test_declarative_attr_set_monad.py index 4f7f855a1..c60610db9 100644 --- a/pony/orm/tests/test_declarative_attr_set_monad.py +++ b/pony/orm/tests/test_declarative_attr_set_monad.py @@ -60,12 +60,12 @@ class Mark(db.Entity): Mark(value=1, student=s3, subject=History) Mark(value=2, student=s3, subject=Math) Mark(value=2, student=s4, subject=Math) - + class TestAttrSetMonad(unittest.TestCase): 
def setUp(self): rollback() db_session.__enter__() - + def tearDown(self): rollback() db_session.__exit__() diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 00a1af29a..7d046321e 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -338,7 +338,7 @@ def test_tuple_param(self): x = Student[1], Student[2] result = set(select(s for s in Student if s not in x)) self.assertEqual(result, set([Student[3]])) - @raises_exception(TypeError, "Expression `x` should not contain None values") + @raises_exception(TypeError, "Expression `x` should not contain None values") def test_tuple_param_2(self): x = Student[1], None result = set(select(s for s in Student if s not in x)) diff --git a/pony/orm/tests/test_declarative_strings.py b/pony/orm/tests/test_declarative_strings.py index bb0f72795..bdeceacf0 100644 --- a/pony/orm/tests/test_declarative_strings.py +++ b/pony/orm/tests/test_declarative_strings.py @@ -32,12 +32,12 @@ def tearDown(self): def test_nonzero(self): result = set(select(s for s in Student if s.foo)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, set([Student[1], Student[2], Student[3]])) def test_add(self): name = 'Jonny' result = set(select(s for s in Student if s.name + "ny" == name)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, set([Student[1]])) def test_slice_1(self): result = set(select(s for s in Student if s.name[0:3] == "Jon")) diff --git a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index d5dc5660b..b0d9aff89 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -393,7 +393,7 @@ class WebinarShow(db.Entity): db.generate_mapping(create_tables=True) self.assertEqual(Stat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') - + def 
test_columns_22(self): db = Database('sqlite', ':memory:') class ZStat(db.Entity): @@ -669,7 +669,7 @@ def test_none_type(self): db = Database('sqlite', ':memory:') class Foo(db.Entity): x = Required(str, sql_default='') - + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py index 6db92d0a7..0aee54622 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -57,7 +57,7 @@ class Person(db.Entity): create_script = db.schema.generate_create_script() index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")' self.assertTrue(index_sql in create_script) - + def test_2(self): db = Database('sqlite', ':memory:') class User(db.Entity): diff --git a/pony/orm/tests/test_inheritance.py b/pony/orm/tests/test_inheritance.py index 565b0750b..4b4173528 100644 --- a/pony/orm/tests/test_inheritance.py +++ b/pony/orm/tests/test_inheritance.py @@ -39,7 +39,7 @@ class Entity4(Entity2, Entity3): self.assertEqual(Entity2._discriminator_, 'Entity2') self.assertEqual(Entity3._discriminator_, 'Entity3') self.assertEqual(Entity4._discriminator_, 'Entity4') - + @raises_exception(ERDiagramError, "Multiple inheritance graph must be diamond-like. 
" "Entity Entity3 inherits from Entity1 and Entity2 entities which don't have common base class.") def test_2(self): @@ -102,7 +102,7 @@ class Entity1(db.Entity): b = Required(int) class Entity2(Entity1): c = Required(int) - + self.assertTrue(Entity1._discriminator_attr_ is Entity1.a) self.assertTrue(Entity2._discriminator_attr_ is Entity1.a) diff --git a/pony/orm/tests/test_objects_to_save_cleanup.py b/pony/orm/tests/test_objects_to_save_cleanup.py index 5c6e5d74a..331d87932 100644 --- a/pony/orm/tests/test_objects_to_save_cleanup.py +++ b/pony/orm/tests/test_objects_to_save_cleanup.py @@ -28,7 +28,7 @@ def test_delete_updated(self): p = TestPost() self.make_flush(p) p.name = 'Pony' - self.assertEqual(p._status_, 'modified') + self.assertEqual(p._status_, 'modified') self.make_flush(p) self.assertEqual(p._status_, 'updated') p.delete() @@ -54,16 +54,15 @@ def test_cancelled(self): self.assertEqual(p._status_, 'cancelled') - class EntityStatusTestCase_ObjectFlush(EntityStatusTestCase, unittest.TestCase): def make_flush(self, obj=None): obj.flush() - + class EntityStatusTestCase_FullFlush(EntityStatusTestCase, unittest.TestCase): def make_flush(self, obj=None): - flush() # full flush \ No newline at end of file + flush() # full flush diff --git a/pony/orm/tests/test_relations_m2m.py b/pony/orm/tests/test_relations_m2m.py index c4fd15224..376bf113e 100644 --- a/pony/orm/tests/test_relations_m2m.py +++ b/pony/orm/tests/test_relations_m2m.py @@ -19,7 +19,7 @@ class Subject(db.Entity): self.db = db self.Group = Group self.Subject = Subject - + self.db.generate_mapping(create_tables=True) with db_session: @@ -208,7 +208,7 @@ def test_13(self): self.assertTrue(s1 in group_setdata) self.assertEqual(group_setdata.added, None) self.assertEqual(group_setdata.removed, None) - + subj_setdata = s1._vals_[Subject.groups] self.assertTrue(g1 in subj_setdata) self.assertEqual(subj_setdata.added, None) @@ -221,7 +221,7 @@ def test_13(self): self.assertTrue(g1 not in subj_setdata) 
self.assertEqual(subj_setdata.added, None) self.assertEqual(subj_setdata.removed, set([ g1 ])) - + g1.subjects.add(s1) self.assertTrue(s1 in group_setdata) self.assertEqual(group_setdata.added, set()) @@ -263,7 +263,7 @@ def test_15(self): e = g.subjects.is_empty() # should take result from the cache self.assertEqual(e, False) self.assertEqual(db.last_sql, None) - + g = Group[102] c = len(g.subjects) self.assertEqual(c, 0) diff --git a/pony/orm/tests/test_relations_one2many.py b/pony/orm/tests/test_relations_one2many.py index 204ee6423..450e59b18 100644 --- a/pony/orm/tests/test_relations_one2many.py +++ b/pony/orm/tests/test_relations_one2many.py @@ -74,7 +74,7 @@ def test_5(self): Group, Student = self.Group, self.Student g = Group[101] g.students.add(None) - + @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None') def test_6(self): Group, Student = self.Group, self.Student @@ -127,7 +127,7 @@ def test_10(self): e = g.students.is_empty() # should take result from the cache self.assertEqual(e, False) self.assertEqual(db.last_sql, None) - + g = Group[102] c = g.students.count() self.assertEqual(c, 2) @@ -255,7 +255,7 @@ def test_5(self): Group, Student = self.Group, self.Student g = Group[101] g.students.add(None) - + @raises_exception(ValueError, 'A single Student instance or Student iterable is expected. Got: None') def test_6(self): Group, Student = self.Group, self.Student diff --git a/pony/orm/tests/test_relations_one2one3.py b/pony/orm/tests/test_relations_one2one3.py index 269ccd1e9..2419267f7 100644 --- a/pony/orm/tests/test_relations_one2one3.py +++ b/pony/orm/tests/test_relations_one2one3.py @@ -70,7 +70,7 @@ def test_5(self): sql = self.db.last_sql self.assertEqual(sql, '''DELETE FROM "Person" WHERE "id" = ? 
- AND "name" = ?''') + AND "name" = ?''') @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session From 6a2f0bda686dfd4a8b431625554da89fa34909b5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 6 Sep 2016 16:22:52 +0300 Subject: [PATCH 083/547] Local variable renaming --- pony/orm/asttranslation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 7e80ce910..2eb94b15e 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -226,13 +226,13 @@ def __init__(translator, tree, globals, locals, def dispatch(translator, node): node.external = node.constant = None ASTTranslator.dispatch(translator, node) - childs = node.getChildNodes() - if node.external is None and childs and all( - getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in childs): + children = node.getChildNodes() + if node.external is None and children and all( + getattr(child, 'external', False) and not getattr(child, 'raw_sql', False) for child in children): node.external = True if node.external and not node.constant: externals = translator.externals - externals.difference_update(childs) + externals.difference_update(children) externals.add(node) def preGenExprInner(translator, node): translator.contexts.append(set()) From e5a8d4f3ffbe662250be5ba416907a59b9fcb221 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 6 Sep 2016 17:53:12 +0300 Subject: [PATCH 084/547] Micro refactoring --- pony/orm/sqltranslation.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 10b6ba51f..f87895c42 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -720,20 +720,21 @@ def postNot(translator, node): def preCallFunc(translator, node): if node.star_args is 
not None: throw(NotImplementedError, '*%s is not supported' % ast2src(node.star_args)) if node.dstar_args is not None: throw(NotImplementedError, '**%s is not supported' % ast2src(node.dstar_args)) - if not isinstance(node.node, (ast.Name, ast.Getattr)): throw(NotImplementedError) + func_node = node.node + if not isinstance(func_node, (ast.Name, ast.Getattr)): throw(NotImplementedError) if len(node.args) > 1: return if not node.args: return arg = node.args[0] if isinstance(arg, ast.GenExpr): - translator.dispatch(node.node) - func_monad = node.node.monad + translator.dispatch(func_node) + func_monad = func_node.monad translator.dispatch(arg) query_set_monad = arg.monad return func_monad(query_set_monad) if not isinstance(arg, ast.Lambda): return lambda_expr = arg - translator.dispatch(node.node) - method_monad = node.node.monad + translator.dispatch(func_node) + method_monad = func_node.monad if not isinstance(method_monad, MethodMonad): throw(NotImplementedError) entity_monad = method_monad.parent if not isinstance(entity_monad, EntityMonad): throw(NotImplementedError) From 736e098cb768d596717dcf36ac1471e8e74e5fcf Mon Sep 17 00:00:00 2001 From: Vitalii Date: Thu, 12 May 2016 14:55:00 +0300 Subject: [PATCH 085/547] getattr support in queries --- pony/orm/asttranslation.py | 42 ++++++++++++++-- pony/orm/core.py | 10 ++-- pony/orm/sqltranslation.py | 8 +++ pony/orm/tests/test_getattr.py | 91 ++++++++++++++++++++++++++++++++++ 4 files changed, 142 insertions(+), 9 deletions(-) create mode 100644 pony/orm/tests/test_getattr.py diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 2eb94b15e..26cd7f775 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -208,6 +208,7 @@ class PreTranslator(ASTTranslator): def __init__(translator, tree, globals, locals, special_functions, const_functions, additional_internal_names=()): ASTTranslator.__init__(translator, tree) + translator.getattr_nodes = set() translator.globals = globals 
translator.locals = locals translator.special_functions = special_functions @@ -284,6 +285,10 @@ def postCallFunc(translator, node): else: if x in translator.special_functions: if x.__name__ == 'raw_sql': node.raw_sql = True + elif x is getattr: + attr_node = node.args[1] + attr_node.parent_node = node + translator.getattr_nodes.add(attr_node) else: node.external = False elif x in translator.const_functions: for arg in node.args: @@ -292,21 +297,50 @@ def postCallFunc(translator, node): if node.dstar_args is not None and not node.dstar_args.constant: return node.constant = True +getattr_cache = {} extractors_cache = {} def create_extractors(code_key, tree, filter_num, globals, locals, special_functions, const_functions, additional_internal_names=()): - cache_key = code_key, filter_num - result = extractors_cache.get(cache_key) - if result is None: + result = None + getattr_key = code_key, filter_num + getattr_extractors = getattr_cache.get(getattr_key) + if getattr_extractors: + getattr_attrname_values = tuple(eval(code, globals, locals) for src, code in getattr_extractors) + extractors_key = (code_key, filter_num, getattr_attrname_values) + try: + result = extractors_cache.get(extractors_key) + except TypeError: + pass # unhashable type + if not result: pretranslator = PreTranslator( tree, globals, locals, special_functions, const_functions, additional_internal_names) + extractors = {} for node in pretranslator.externals: src = node.src = ast2src(node) if src == '.0': code = None else: code = compile(src, src, 'eval') extractors[filter_num, src] = code + + getattr_extractors = {} + getattr_attrname_values = {} + for node in pretranslator.getattr_nodes: + if node in pretranslator.externals: + code = extractors[filter_num, node.src] + getattr_extractors[src] = code + attrname_value = eval(code, globals, locals) + getattr_attrname_values[src] = attrname_value + elif isinstance(node, ast.Const): + attrname_value = node.value + else: throw(TypeError, '`%s` should be 
either external expression or constant.' % ast2src(node)) + if not isinstance(attrname_value, basestring): throw(TypeError, + '%s: attribute name must be string. Got: %r' % (ast2src(node.parent_node), attrname_value)) + node._attrname_value = attrname_value + getattr_cache[getattr_key] = tuple(sorted(getattr_extractors.items())) + varnames = list(sorted(extractors)) - result = extractors_cache[cache_key] = extractors, varnames, tree + getattr_attrname_values = tuple(val for key, val in sorted(getattr_attrname_values.items())) + extractors_key = (code_key, filter_num, getattr_attrname_values) + result = extractors_cache[extractors_key] = extractors, varnames, tree, extractors_key return result diff --git a/pony/orm/core.py b/pony/orm/core.py index 2bb959ab6..2485fc4fa 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5043,7 +5043,7 @@ def unpickle_ast(pickled): class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) - extractors, varnames, tree = create_extractors( + extractors, varnames, tree, pretranslator_key = create_extractors( code_key, tree, 0, globals, locals, special_functions, const_functions) vars, vartypes = extract_vars(extractors, globals, locals, cells) @@ -5059,7 +5059,7 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) database.provider.normalize_vars(vars, vartypes) query._vars = vars - query._key = code_key, tuple(vartypes[name] for name in varnames), left_join + query._key = pretranslator_key, tuple(vartypes[name] for name in varnames), left_join query._database = database translator = database._translator_cache.get(query._key) @@ -5359,7 +5359,7 @@ def _process_lambda(query, func, globals, locals, order_by): 'Expected: %d, got: %d' % (expr_count, len(argnames))) filter_num = len(query._filters) + 1 - 
extractors, varnames, func_ast = create_extractors( + extractors, varnames, func_ast, pretranslator_key = create_extractors( func_id, func_ast, filter_num, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) if extractors: @@ -5370,7 +5370,7 @@ def _process_lambda(query, func, globals, locals, order_by): sorted_vartypes = tuple(vartypes[name] for name in varnames) else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () - new_key = query._key + (('order_by' if order_by else 'filter', func_id, sorted_vartypes),) + new_key = query._key + (('order_by' if order_by else 'filter', pretranslator_key, sorted_vartypes),) new_filters = query._filters + ((order_by, func_ast, argnames, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: @@ -5612,5 +5612,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = set([ itertools.count, utils.count, count, random, raw_sql ]) +special_functions = set([ itertools.count, utils.count, count, random, raw_sql, getattr ]) const_functions = set([ buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta ]) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index f87895c42..a3f454124 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -721,6 +721,8 @@ def preCallFunc(translator, node): if node.star_args is not None: throw(NotImplementedError, '*%s is not supported' % ast2src(node.star_args)) if node.dstar_args is not None: throw(NotImplementedError, '**%s is not supported' % ast2src(node.dstar_args)) func_node = node.node + if isinstance(func_node, ast.CallFunc): + if isinstance(func_node.node, ast.Name) and func_node.node.name == 'getattr': return if not isinstance(func_node, (ast.Name, ast.Getattr)): throw(NotImplementedError) if len(node.args) > 1: return if not node.args: return @@ -2044,6 +2046,12 @@ class FuncLenMonad(FuncMonad): 
def call(monad, x): return x.len() +class GetattrMonad(FuncMonad): + func = getattr + def call(monad, obj_monad, name_monad): + name = name_monad.node._attrname_value + return obj_monad.getattr(name) + class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count def call(monad, x=None): diff --git a/pony/orm/tests/test_getattr.py b/pony/orm/tests/test_getattr.py new file mode 100644 index 000000000..c3a82b0d9 --- /dev/null +++ b/pony/orm/tests/test_getattr.py @@ -0,0 +1,91 @@ +from pony.py23compat import basestring + +import unittest + +from pony.orm import * +from pony import orm +from pony.utils import cached_property +from pony.orm.tests.testutils import raises_exception + +class Test(unittest.TestCase): + + @cached_property + def db(self): + return orm.Database('sqlite', ':memory:') + + def setUp(self): + db = self.db + + class Genre(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') + + class Hobby(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') + + class Artist(db.Entity): + name = orm.Required(str) + age = orm.Optional(int) + hobbies = orm.Set(Hobby) + genres = orm.Set(Genre) + + db.generate_mapping(check_tables=True, create_tables=True) + + with orm.db_session: + pop = Genre(name='pop') + Artist(name='Sia', age=40, genres=[pop]) + + pony.options.INNER_JOIN_SYNTAX = True + + @db_session + def test_no_caching(self): + for attr, type in zip(['name', 'age'], [basestring, int]): + val = select(getattr(x, attr) for x in self.db.Artist).first() + self.assertIsInstance(val, type) + + @db_session + def test_simple(self): + val = select(getattr(x, 'age') for x in self.db.Artist).first() + self.assertIsInstance(val, int) + + @db_session + def test_expr(self): + val = select(getattr(x, ''.join(['ag', 'e'])) for x in self.db.Artist).first() + self.assertIsInstance(val, int) + + @db_session + def test_external(self): + class data: + id = 1 + val = select(x.id for x in self.db.Artist if x.id >= getattr(data, 
'id')).first() + self.assertIsNotNone(val) + + @db_session + def test_related(self): + val = select(getattr(x.genres, 'name') for x in self.db.Artist).first() + self.assertIsNotNone(val) + + @db_session + def test_not_instance_iter(self): + val = select(getattr(x.name, 'startswith')('S') for x in self.db.Artist).first() + self.assertTrue(val) + + @db_session + def test_not_external(self): + with self.assertRaisesRegexp(TypeError, 'should be either external expression or constant'): + select(getattr(x, x.name) for x in self.db.Artist) + + @raises_exception(TypeError, 'getattr(x, 1): attribute name must be string. Got: 1') + @db_session + def test_not_string(self): + select(getattr(x, 1) for x in self.db.Artist) + + + @raises_exception(TypeError, 'getattr(x, name): attribute name must be string. Got: 1') + @db_session + def test_not_string(self): + name = 1 + select(getattr(x, name) for x in self.db.Artist) + + From e9df8898c48afad680aaa3dba516c8686ab0b26b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 16:54:53 +0300 Subject: [PATCH 086/547] Python3 fix --- pony/orm/asttranslation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 26cd7f775..ca5e57661 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import basestring from functools import update_wrapper From 51a513b5256c5cf3624c619d29dbc779367b9384 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 17:32:18 +0300 Subject: [PATCH 087/547] Fix test in Python 3 --- pony/orm/tests/test_getattr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/tests/test_getattr.py b/pony/orm/tests/test_getattr.py index c3a82b0d9..8704f9f5d 100644 --- a/pony/orm/tests/test_getattr.py +++ b/pony/orm/tests/test_getattr.py @@ -72,9 +72,9 @@ def test_not_instance_iter(self): 
self.assertTrue(val) @db_session + @raises_exception(TypeError, '`x.name` should be either external expression or constant.') def test_not_external(self): - with self.assertRaisesRegexp(TypeError, 'should be either external expression or constant'): - select(getattr(x, x.name) for x in self.db.Artist) + select(getattr(x, x.name) for x in self.db.Artist) @raises_exception(TypeError, 'getattr(x, 1): attribute name must be string. Got: 1') @db_session From c9c8761893ff2e9304d32c4faa8513c50ee440b5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 13:48:36 +0300 Subject: [PATCH 088/547] Remove obsolete code --- pony/utils/utils.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index f6a04e15a..d8d4e3b5f 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -13,17 +13,9 @@ from locale import getpreferredencoding from bisect import bisect from collections import defaultdict -from copy import deepcopy, _deepcopy_dispatch from functools import update_wrapper from xml.etree import cElementTree -# deepcopy instance method patch for Python < 2.7: -if types.MethodType not in _deepcopy_dispatch: - assert PY2 - def _deepcopy_method(x, memo): - return type(x)(x.im_func, deepcopy(x.im_self, memo), x.im_class) - _deepcopy_dispatch[types.MethodType] = _deepcopy_method - import pony from pony import options From 38290f309a8c961a35d73f3490dbf8a5d47c0cc4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 14:00:28 +0300 Subject: [PATCH 089/547] Turn obsolete check into assertion --- pony/orm/core.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2485fc4fa..1beb533ab 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -378,8 +378,8 @@ def __exit__(db_session, exc_type=None, exc=None, tb=None): elif not callable(db_session.allowed_exceptions): can_commit = issubclass(exc_type, 
tuple(db_session.allowed_exceptions)) else: - # exc can be None in Python 2.6 even if exc_type is not None - try: can_commit = exc is not None and db_session.allowed_exceptions(exc) + assert exc is not None # exc can be None in Python 2.6 even if exc_type is not None + try: can_commit = db_session.allowed_exceptions(exc) except: rollback() raise @@ -405,12 +405,13 @@ def new_func(func, *args, **kwargs): exc_type = exc = tb = None try: return func(*args, **kwargs) except: - exc_type, exc, tb = sys.exc_info() # exc can be None in Python 2.6 + exc_type, exc, tb = sys.exc_info() retry_exceptions = db_session.retry_exceptions if not callable(retry_exceptions): do_retry = issubclass(exc_type, tuple(retry_exceptions)) else: - do_retry = exc is not None and retry_exceptions(exc) + assert exc is not None # exc can be None in Python 2.6 + do_retry = retry_exceptions(exc) if not do_retry: raise finally: db_session.__exit__(exc_type, exc, tb) reraise(exc_type, exc, tb) From 68dd244be47ab42ea0dde60f8ae16cd55c645b21 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 14:57:08 +0300 Subject: [PATCH 090/547] Simplify code --- pony/orm/decompiling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index d989caf14..bed528309 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -74,7 +74,7 @@ def __init__(decompiler, code, start=0, end=None): decompiler.assnames = set() decompiler.decompile() decompiler.ast = decompiler.stack.pop() - decompiler.external_names = set(decompiler.names - decompiler.assnames) + decompiler.external_names = decompiler.names - decompiler.assnames assert not decompiler.stack, decompiler.stack def decompile(decompiler): code = decompiler.code From 32eada962c0b093de6a2e284b7313600735b7922 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 14:43:11 +0300 Subject: [PATCH 091/547] Use dict comprehension syntax --- pony/orm/core.py | 10 
+++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 1beb533ab..1a6efd40b 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -570,7 +570,7 @@ def merge_local_stats(database): @property def global_stats(database): with database._global_stats_lock: - return dict((sql, stat.copy()) for sql, stat in iteritems(database._global_stats)) + return {sql: stat.copy() for sql, stat in iteritems(database._global_stats)} @property def global_stats_lock(database): deprecated(3, "global_stats_lock is deprecated, just use global_stats property without any locking") @@ -1173,7 +1173,7 @@ def deserialize(x): if t is list: return list(imap(deserialize, x)) if t is dict: if '_id_' not in x: - return dict((key, deserialize(val)) for key, val in iteritems(x)) + return {key: deserialize(val) for key, val in iteritems(x)} obj = objmap.get(x['_id_']) if obj is None: entity_name = x['class'] @@ -3370,7 +3370,7 @@ def __init__(entity, name, bases, cls_dict): entity._new_attrs_ = new_attrs entity._attrs_ = base_attrs + new_attrs - entity._adict_ = dict((attr.name, attr) for attr in entity._attrs_) + entity._adict_ = {attr.name: attr for attr in entity._attrs_} entity._subclass_attrs_ = [] entity._subclass_adict_ = {} for base in entity._all_bases_: @@ -3553,7 +3553,7 @@ def __getitem__(entity, key): if len(key) != len(entity._pk_attrs_): throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)' % (entity.__name__, len(key), len(entity._pk_attrs_))) - kwargs = dict(izip(imap(attrgetter('name'), entity._pk_attrs_), key)) + kwargs = {attr.name: value for attr, value in izip(entity._pk_attrs_, key)} return entity._find_one_(kwargs) @cut_traceback def exists(entity, *args, **kwargs): @@ -3709,7 +3709,7 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False): return None, unique def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False): database = entity._database_ - query_attrs = 
dict((attr, value is None) for attr, value in iteritems(avdict)) + query_attrs = {attr: value is None for attr, value in iteritems(avdict)} limit = 2 if not unique else None sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait) arguments = adapter(avdict) From 70868c74eae3d26e5d961dfda6dfb65993e9f8fb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 14:56:08 +0300 Subject: [PATCH 092/547] Use set comprehension syntax --- pony/orm/core.py | 12 ++++++------ pony/orm/sqltranslation.py | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 1a6efd40b..eb112a229 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2501,7 +2501,7 @@ def load(attr, obj, items=None): if items: if not reverse.is_collection: - items = set(item for item in items if reverse not in item._vals_) + items = {item for item in items if reverse not in item._vals_} else: items = set(items) items -= setdata @@ -2521,7 +2521,7 @@ def load(attr, obj, items=None): items.append(obj) arguments = adapter(items) cursor = database._exec_sql(sql, arguments) - loaded_items = set(imap(rentity._get_by_raw_pkval_, cursor.fetchall())) + loaded_items = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} setdata |= loaded_items reverse.db_reverse_add(loaded_items, obj) return setdata @@ -2564,7 +2564,7 @@ def load(attr, obj, items=None): items = d.get(obj2) if items is None: items = d[obj2] = set() items.add(item) - else: d[obj] = set(imap(rentity._get_by_raw_pkval_, cursor.fetchall())) + else: d[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} for obj2, items in iteritems(d): setdata2 = obj2._vals_.get(attr) if setdata2 is None: setdata2 = obj._vals_[attr] = SetData() @@ -3333,7 +3333,7 @@ def __init__(entity, name, bases, cls_dict): for attr in new_attrs: if attr.is_unique: indexes.append(Index(attr, is_pk=isinstance(attr, PrimaryKey))) for index in indexes: 
index._init_(entity) - primary_keys = set(index.attrs for index in indexes if index.is_pk) + primary_keys = {index.attrs for index in indexes if index.is_pk} if direct_bases: if primary_keys: throw(ERDiagramError, 'Primary key cannot be redefined in derived classes') base_indexes = [] @@ -3341,7 +3341,7 @@ def __init__(entity, name, bases, cls_dict): for index in base._indexes_: if index not in base_indexes and index not in indexes: base_indexes.append(index) indexes[:0] = base_indexes - primary_keys = set(index.attrs for index in indexes if index.is_pk) + primary_keys = {index.attrs for index in indexes if index.is_pk} if len(primary_keys) > 1: throw(ERDiagramError, 'Only one primary key can be defined in each entity class') elif not primary_keys: @@ -3883,7 +3883,7 @@ def _load_many_(entity, objects): cache = database._get_cache() seeds = cache.seeds[entity._pk_attrs_] if not seeds: return - objects = set(obj for obj in objects if obj in seeds) + objects = {obj for obj in objects if obj in seeds} objects = sorted(objects, key=attrgetter('_pkval_')) max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) while objects: diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a3f454124..22135e0e2 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2406,7 +2406,7 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F outer_conditions = subquery.outer_conditions groupby_columns = [ inner_column[:] for cond, outer_column, inner_column in outer_conditions ] - assert len(set(alias for _, alias, column in groupby_columns)) == 1 + assert len({alias for _, alias, column in groupby_columns}) == 1 if extra_grouping: inner_alias = translator.subquery.get_short_alias(None, 't') From d09e0dd4dea26b777d0807a7722a22c563fb4ae2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Sep 2016 14:34:03 +0300 Subject: [PATCH 093/547] Use set literal syntax --- pony/orm/core.py | 14 
++-- pony/orm/ormtypes.py | 8 +-- pony/orm/tests/test_crud_raw_sql.py | 6 +- .../tests/test_declarative_attr_set_monad.py | 46 ++++++------ pony/orm/tests/test_declarative_func_monad.py | 24 +++---- .../test_declarative_object_flat_monad.py | 4 +- .../tests/test_declarative_orderby_limit.py | 28 ++++---- .../tests/test_declarative_query_set_monad.py | 46 ++++++------ .../tests/test_declarative_sqltranslator.py | 71 +++++++++---------- .../tests/test_declarative_sqltranslator2.py | 24 +++---- pony/orm/tests/test_declarative_strings.py | 66 ++++++++--------- pony/orm/tests/test_diagram.py | 8 +-- pony/orm/tests/test_diagram_attribute.py | 4 +- pony/orm/tests/test_filter.py | 10 +-- pony/orm/tests/test_frames.py | 12 ++-- pony/orm/tests/test_prefetching.py | 2 +- pony/orm/tests/test_raw_sql.py | 28 ++++---- pony/orm/tests/test_relations_m2m.py | 6 +- pony/orm/tests/test_relations_one2many.py | 6 +- .../orm/tests/test_relations_symmetric_m2m.py | 10 +-- .../tests/test_relations_symmetric_one2one.py | 2 +- 21 files changed, 212 insertions(+), 213 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index eb112a229..2d03c2285 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1337,7 +1337,7 @@ def get_user_groups(user): result = local.user_groups_cache.get(user) if result is not None: return result if user is None: return anybody_frozenset - result = set(['anybody']) + result = {'anybody'} for cls, func in usergroup_functions: if cls is None or isinstance(user, cls): groups = func(user) @@ -4113,10 +4113,10 @@ def populate_criteria_list(criteria_list, columns, converters, operations, params_count += 1 return params_count -statuses = set(['created', 'cancelled', 'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted']) -del_statuses = set(['marked_to_delete', 'deleted', 'cancelled']) -created_or_deleted_statuses = set(['created']) | del_statuses -saved_statuses = set(['inserted', 'updated', 'deleted']) +statuses = {'created', 'cancelled', 
'loaded', 'modified', 'inserted', 'updated', 'marked_to_delete', 'deleted'} +del_statuses = {'marked_to_delete', 'deleted', 'cancelled'} +created_or_deleted_statuses = {'created'} | del_statuses +saved_statuses = {'inserted', 'updated', 'deleted'} def throw_object_was_deleted(obj): assert obj._status_ in del_statuses @@ -5613,5 +5613,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = set([ itertools.count, utils.count, count, random, raw_sql, getattr ]) -const_functions = set([ buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta ]) +special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr} +const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta} diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index e0c475772..c9278d208 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -122,10 +122,10 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) -numeric_types = set([ bool, int, float, Decimal ]) -comparable_types = set([ int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID ]) -primitive_types = comparable_types | set([ buffer ]) -function_types = set([type, types.FunctionType, types.BuiltinFunctionType]) +numeric_types = {bool, int, float, Decimal} +comparable_types = {int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID} +primitive_types = comparable_types | {buffer} +function_types = {type, types.FunctionType, types.BuiltinFunctionType} type_normalization_dict = { long : int } if PY2 else {} def get_normalized_type_of(value): diff --git a/pony/orm/tests/test_crud_raw_sql.py b/pony/orm/tests/test_crud_raw_sql.py index 4083a15f2..43af748a5 100644 --- a/pony/orm/tests/test_crud_raw_sql.py +++ b/pony/orm/tests/test_crud_raw_sql.py @@ -45,16 +45,16 @@ def tearDown(self): def test1(self): students = set(Student.select_by_sql("select id, name, age, 
group_dept, group_grad_year from Student order by age")) - self.assertEqual(students, set([Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[3], Student[2], Student[1]}) def test2(self): students = set(Student.select_by_sql("select id, age, group_dept from Student order by age")) - self.assertEqual(students, set([Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[3], Student[2], Student[1]}) @raises_exception(NameError, "Column x does not belong to entity Student") def test3(self): students = set(Student.select_by_sql("select id, age, age*2 as x from Student order by age")) - self.assertEqual(students, set([Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[3], Student[2], Student[1]}) @raises_exception(TypeError, 'The first positional argument must be lambda function or its text source. Got: 123') def test4(self): diff --git a/pony/orm/tests/test_declarative_attr_set_monad.py b/pony/orm/tests/test_declarative_attr_set_monad.py index c60610db9..f6f9a6418 100644 --- a/pony/orm/tests/test_declarative_attr_set_monad.py +++ b/pony/orm/tests/test_declarative_attr_set_monad.py @@ -75,7 +75,7 @@ def test1(self): self.assertEqual(groups, [Group[41]]) def test2(self): groups = set(select(g for g in Group if len(g.students.name) >= 2)) - self.assertEqual(groups, set([Group[41], Group[42]])) + self.assertEqual(groups, {Group[41], Group[42]}) def test3(self): groups = select(g for g in Group if len(g.students.marks) > 2)[:] self.assertEqual(groups, [Group[41]]) @@ -90,78 +90,78 @@ def test5(self): self.assertEqual(students, []) def test6(self): students = set(select(s for s in Student if len(s.marks.subject) >= 2)) - self.assertEqual(students, set([Student[2], Student[3]])) + self.assertEqual(students, {Student[2], Student[3]}) def test8(self): students = set(select(s for s in Student if s.group in (g for g in Group if g.department == 101))) - self.assertEqual(students, set([Student[1], Student[2], 
Student[3]])) + self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test9(self): students = set(select(s for s in Student if s.group not in (g for g in Group if g.department == 101))) - self.assertEqual(students, set([Student[4], Student[5], Student[6]])) + self.assertEqual(students, {Student[4], Student[5], Student[6]}) def test10(self): students = set(select(s for s in Student if s.group in (g for g in Group if g.department == 101))) - self.assertEqual(students, set([Student[1], Student[2], Student[3]])) + self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test11(self): students = set(select(g for g in Group if len(g.subjects.groups.subjects) > 1)) - self.assertEqual(students, set([Group[41], Group[42], Group[43]])) + self.assertEqual(students, {Group[41], Group[42], Group[43]}) def test12(self): groups = set(select(g for g in Group if len(g.subjects) >= 2)) - self.assertEqual(groups, set([Group[41], Group[42]])) + self.assertEqual(groups, {Group[41], Group[42]}) def test13(self): groups = set(select(g for g in Group if g.students)) - self.assertEqual(groups, set([Group[41], Group[42], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test14(self): groups = set(select(g for g in Group if not g.students)) - self.assertEqual(groups, set([Group[43]])) + self.assertEqual(groups, {Group[43]}) def test15(self): groups = set(select(g for g in Group if exists(g.students))) - self.assertEqual(groups, set([Group[41], Group[42], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test15a(self): groups = set(select(g for g in Group if not not exists(g.students))) - self.assertEqual(groups, set([Group[41], Group[42], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) def test16(self): groups = select(g for g in Group if not exists(g.students))[:] self.assertEqual(groups, [Group[43]]) def test17(self): groups = set(select(g for g in Group if 100 in 
g.students.scholarship)) - self.assertEqual(groups, set([Group[41]])) + self.assertEqual(groups, {Group[41]}) def test18(self): groups = set(select(g for g in Group if 100 not in g.students.scholarship)) - self.assertEqual(groups, set([Group[42], Group[43], Group[44]])) + self.assertEqual(groups, {Group[42], Group[43], Group[44]}) def test19(self): groups = set(select(g for g in Group if not not not 100 not in g.students.scholarship)) - self.assertEqual(groups, set([Group[41]])) + self.assertEqual(groups, {Group[41]}) def test20(self): groups = set(select(g for g in Group if exists(s for s in Student if s.group == g and s.scholarship == 500))) - self.assertEqual(groups, set([Group[41], Group[42]])) + self.assertEqual(groups, {Group[41], Group[42]}) def test21(self): groups = set(select(g for g in Group if g.department is not None)) - self.assertEqual(groups, set([Group[41], Group[42], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[43], Group[44]}) def test21a(self): groups = set(select(g for g in Group if not g.department is not None)) - self.assertEqual(groups, set([])) + self.assertEqual(groups, set()) def test21b(self): groups = set(select(g for g in Group if not not not g.department is None)) - self.assertEqual(groups, set([Group[41], Group[42], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[42], Group[43], Group[44]}) def test22(self): groups = set(select(g for g in Group if 700 in (s.scholarship for s in Student if s.group == g))) - self.assertEqual(groups, set([Group[42]])) + self.assertEqual(groups, {Group[42]}) def test23a(self): groups = set(select(g for g in Group if 700 not in g.students.scholarship)) - self.assertEqual(groups, set([Group[41], Group[43], Group[44]])) + self.assertEqual(groups, {Group[41], Group[43], Group[44]}) def test23b(self): groups = set(select(g for g in Group if 700 not in (s.scholarship for s in Student if s.group == g))) - self.assertEqual(groups, set([Group[41], Group[43], 
Group[44]])) + self.assertEqual(groups, {Group[41], Group[43], Group[44]}) @raises_exception(NotImplementedError) def test24(self): groups = set(select(g for g in Group for g2 in Group if g.students == g2.students)) def test25(self): m1 = Mark[Student[1], Subject["Math"]] students = set(select(s for s in Student if m1 in s.marks)) - self.assertEqual(students, set([Student[1]])) + self.assertEqual(students, {Student[1]}) def test26(self): s1 = Student[1] groups = set(select(g for g in Group if s1 in g.students)) - self.assertEqual(groups, set([Group[41]])) + self.assertEqual(groups, {Group[41]}) @raises_exception(AttributeError, 'g.students.name.foo') def test27(self): select(g for g in Group if g.students.name.foo == 1) diff --git a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index 09cba95a2..a0c17d51e 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -56,16 +56,16 @@ def tearDown(self): db_session.__exit__() def test_minmax1(self): result = set(select(s for s in Student if max(s.id, 3) == 3 )) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_minmax2(self): result = set(select(s for s in Student if min(s.id, 3) == 3 )) - self.assertEqual(result, set([Student[4], Student[5], Student[3]])) + self.assertEqual(result, {Student[4], Student[5], Student[3]}) def test_minmax3(self): result = set(select(s for s in Student if max(s.name, "CC") == "CC" )) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_minmax4(self): result = set(select(s for s in Student if min(s.name, "CC") == "CC" )) - self.assertEqual(result, set([Student[4], Student[5], Student[3]])) + self.assertEqual(result, {Student[4], Student[5], Student[3]}) def test_minmax5(self): x = chr(128) try: result = 
set(select(s for s in Student if min(s.name, x) == "CC" )) @@ -82,7 +82,7 @@ def test_minmax7(self): result = set(select(s for s in Student if min(s.phd, 2) == 2 )) def test_date_func1(self): result = set(select(s for s in Student if s.dob >= date(1983, 3, 3))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) # @raises_exception(ExprEvalError, "date(1983, 'three', 3) raises TypeError: an integer is required") @raises_exception(TypeError, "'month' argument of date(year, month, day) function must be of 'int' type. " "Got: '%s'" % unicode.__name__) @@ -94,13 +94,13 @@ def test_date_func2(self): # result = set(select(s for s in Student if s.dob >= date(1983, d, 3))) def test_datetime_func1(self): result = set(select(s for s in Student if s.last_visit >= date(2011, 3, 3))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_datetime_func2(self): result = set(select(s for s in Student if s.last_visit >= datetime(2011, 3, 3))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_datetime_func3(self): result = set(select(s for s in Student if s.last_visit >= datetime(2011, 3, 3, 13, 13, 13))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) # @raises_exception(ExprEvalError, "datetime(1983, 'three', 3) raises TypeError: an integer is required") @raises_exception(TypeError, "'month' argument of datetime(...) function must be of 'int' type. 
" "Got: '%s'" % unicode.__name__) @@ -112,7 +112,7 @@ def test_datetime_func4(self): # result = set(select(s for s in Student if s.last_visit >= date(1983, d, 3))) def test_datetime_now1(self): result = set(select(s for s in Student if s.dob < date.today())) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) @raises_exception(ExprEvalError, "1 < datetime.now() raises TypeError: " + ("can't compare datetime.datetime to int" if PY2 else "unorderable types: int() < datetime.datetime()")) @@ -120,13 +120,13 @@ def test_datetime_now2(self): select(s for s in Student if 1 < datetime.now()) def test_datetime_now3(self): result = set(select(s for s in Student if s.dob < datetime.today())) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) def test_decimal_func(self): result = set(select(s for s in Student if s.scholarship >= Decimal("303.3"))) - self.assertEqual(result, set([Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_concat_1(self): result = set(select(concat(s.name, ':', s.dob.year, ':', s.scholarship) for s in Student)) - self.assertEqual(result, set(['AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'])) + self.assertEqual(result, {'AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'}) @raises_exception(TranslationError, 'Invalid argument of concat() function: g.students') def test_concat_2(self): result = set(select(concat(g.number, g.students) for g in Group)) diff --git a/pony/orm/tests/test_declarative_object_flat_monad.py b/pony/orm/tests/test_declarative_object_flat_monad.py index f374962e8..c3c33c0c5 100644 --- a/pony/orm/tests/test_declarative_object_flat_monad.py +++ 
b/pony/orm/tests/test_declarative_object_flat_monad.py @@ -57,12 +57,12 @@ class TestObjectFlatMonad(unittest.TestCase): @db_session def test1(self): result = set(select(s.groups for s in Subject if len(s.name) == 4)) - self.assertEqual(result, set([Group[41], Group[42]])) + self.assertEqual(result, {Group[41], Group[42]}) @db_session def test2(self): result = set(select(g.students for g in Group if g.department == 102)) - self.assertEqual(result, set([Student[5], Student[4]])) + self.assertEqual(result, {Student[5], Student[4]}) if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_declarative_orderby_limit.py b/pony/orm/tests/test_declarative_orderby_limit.py index a880634b4..fe36e1f1a 100644 --- a/pony/orm/tests/test_declarative_orderby_limit.py +++ b/pony/orm/tests/test_declarative_orderby_limit.py @@ -32,34 +32,34 @@ def tearDown(self): def test1(self): students = set(select(s for s in Student).order_by(Student.name)) - self.assertEqual(students, set([Student[3], Student[1], Student[2], Student[4], Student[5]])) + self.assertEqual(students, {Student[3], Student[1], Student[2], Student[4], Student[5]}) def test2(self): students = set(select(s for s in Student).order_by(Student.name.asc)) - self.assertEqual(students, set([Student[3], Student[1], Student[2], Student[4], Student[5]])) + self.assertEqual(students, {Student[3], Student[1], Student[2], Student[4], Student[5]}) def test3(self): students = set(select(s for s in Student).order_by(Student.id.desc)) - self.assertEqual(students, set([Student[5], Student[4], Student[3], Student[2], Student[1]])) + self.assertEqual(students, {Student[5], Student[4], Student[3], Student[2], Student[1]}) def test4(self): students = set(select(s for s in Student).order_by(Student.scholarship.asc, Student.group.desc)) - self.assertEqual(students, set([Student[1], Student[4], Student[3], Student[5], Student[2]])) + self.assertEqual(students, {Student[1], Student[4], Student[3], Student[5], Student[2]}) def 
test5(self): students = set(select(s for s in Student).order_by(Student.name).limit(3)) - self.assertEqual(students, set([Student[3], Student[1], Student[2]])) + self.assertEqual(students, {Student[3], Student[1], Student[2]}) def test6(self): students = set(select(s for s in Student).order_by(Student.name).limit(3, 1)) - self.assertEqual(students, set([Student[1], Student[2], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[4]}) def test7(self): q = select(s for s in Student).order_by(Student.name).limit(3, 1) students = set(q) - self.assertEqual(students, set([Student[1], Student[2], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[4]}) students = set(q) - self.assertEqual(students, set([Student[1], Student[2], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[4]}) # @raises_exception(TypeError, "query.order_by() arguments must be attributes. Got: 'name'") # now generate: ExprEvalError: name raises NameError: name 'name' is not defined @@ -68,11 +68,11 @@ def test7(self): def test9(self): students = set(select(s for s in Student).order_by(Student.id)[1:4]) - self.assertEqual(students, set([Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test10(self): students = set(select(s for s in Student).order_by(Student.id)[:4]) - self.assertEqual(students, set([Student[1], Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4]}) @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified") def test11(self): @@ -93,19 +93,19 @@ def test13(self): def test15(self): students = set(select(s for s in Student).order_by(Student.id)[0:4][1:3]) - self.assertEqual(students, set([Student[2], Student[3]])) + self.assertEqual(students, {Student[2], Student[3]}) def test16(self): students = set(select(s for s in Student).order_by(Student.id)[0:4][1:]) - 
self.assertEqual(students, set([Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test17(self): students = set(select(s for s in Student).order_by(Student.id)[:4][1:]) - self.assertEqual(students, set([Student[2], Student[3], Student[4]])) + self.assertEqual(students, {Student[2], Student[3], Student[4]}) def test18(self): students = set(select(s for s in Student).order_by(Student.id)[:]) - self.assertEqual(students, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) + self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4], Student[5]}) def test19(self): q = select(s for s in Student).order_by(Student.id) diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 07555ecd0..183a64446 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -47,36 +47,36 @@ def tearDown(self): def test_len(self): result = set(select(g for g in Group if len(g.students) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_len_2(self): result = set(select(g for g in Group if len(s for s in Student if s.group == g) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_len_3(self): result = set(select(g for g in Group if len(s.name for s in Student if s.group == g) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_count_1(self): result = set(select(g for g in Group if count(s.name for s in g.students) > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_count_2(self): result = set(select(g for g in Group if select(s.name for s in g.students).count() > 1)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_count_3(self): result = set(select(s for s 
in Student if count(c for c in s.courses) > 1)) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) def test_count_4(self): result = set(select(c for c in Course if count(s for s in c.students) > 1)) - self.assertEqual(result, set([Course['C1', 1], Course['C2', 1]])) + self.assertEqual(result, {Course['C1', 1], Course['C2', 1]}) @raises_exception(TypeError) def test_sum_1(self): result = set(select(g for g in Group if sum(s for s in Student if s.group == g) > 1)) - self.assertEqual(result, set([])) + self.assertEqual(result, set()) @raises_exception(TypeError) def test_sum_2(self): @@ -84,15 +84,15 @@ def test_sum_2(self): def test_sum_3(self): result = set(select(g for g in Group if sum(s.scholarship for s in Student if s.group == g) > 500)) - self.assertEqual(result, set([])) + self.assertEqual(result, set()) def test_sum_4(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum() > 200)) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_min_1(self): result = set(select(g for g in Group if min(s.name for s in Student if s.group == g) == 'S1')) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) @raises_exception(TypeError) def test_min_2(self): @@ -100,11 +100,11 @@ def test_min_2(self): def test_min_3(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).min() == 0)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_max_1(self): result = set(select(g for g in Group if max(s.scholarship for s in Student if s.group == g) > 100)) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) @raises_exception(TypeError) def test_max_2(self): @@ -112,7 +112,7 @@ def test_max_2(self): def test_max_3(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).max() == 100)) - 
self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_avg_1(self): result = select(g for g in Group if avg(s.scholarship for s in Student if s.group == g) == 50)[:] @@ -120,40 +120,40 @@ def test_avg_1(self): def test_avg_2(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg() == 50)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_exists(self): result = set(select(g for g in Group if exists(s for s in g.students if s.name == 'S1'))) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_negate(self): result = set(select(g for g in Group if not(s.scholarship for s in Student if s.group == g))) - self.assertEqual(result, set([])) + self.assertEqual(result, set()) def test_no_conditions(self): students = set(select(s for s in Student if s.group in (g for g in Group))) - self.assertEqual(students, set([Student[1], Student[2], Student[3]])) + self.assertEqual(students, {Student[1], Student[2], Student[3]}) def test_no_conditions_2(self): students = set(select(s for s in Student if s.scholarship == max(s.scholarship for s in Student))) - self.assertEqual(students, set([Student[3]])) + self.assertEqual(students, {Student[3]}) def test_hint_join_1(self): result = set(select(s for s in Student if JOIN(s.group in select(g for g in Group if g.id < 2)))) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_hint_join_2(self): result = set(select(s for s in Student if JOIN(s.group not in select(g for g in Group if g.id < 2)))) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) def test_hint_join_3(self): result = set(select(s for s in Student if JOIN(s.scholarship in select(s.scholarship + 100 for s in Student if s.name != 'S2')))) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def 
test_hint_join_4(self): result = set(select(g for g in Group if JOIN(g in select(s.group for s in g.students)))) - self.assertEqual(result, set([Group[1], Group[2]])) + self.assertEqual(result, {Group[1], Group[2]}) if __name__ == "__main__": unittest.main() diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 7d046321e..c28832374 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -97,14 +97,14 @@ def tearDown(self): db_session.__exit__() def test_select1(self): result = set(select(s for s in Student)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_select_param(self): result = select(s for s in Student if s.name == name1)[:] self.assertEqual(result, [Student[1]]) def test_select_object_param(self): stud1 = Student[1] result = set(select(s for s in Student if s != stud1)) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) def test_select_deref(self): x = 'S1' result = select(s for s in Student if s.name == x)[:] @@ -132,8 +132,7 @@ def test_function_min2(self): def test_min3(self): d = date(2011, 1, 1) result = set(select(g for g in Grade if min(g.date, d) == d and g.date is not None)) - self.assertEqual(result, set([Grade[Student[1], Course[u'Math', 1]], - Grade[Student[1], Course[u'Physics', 2]]])) + self.assertEqual(result, {Grade[Student[1], Course[u'Math', 1]], Grade[Student[1], Course[u'Physics', 2]]}) def test_function_len1(self): result = select(s for s in Student if len(s.grades) == 1)[:] self.assertEqual(result, [Student[2]]) @@ -168,13 +167,13 @@ def test_builtin_in_locals(self): # select(s for s in Student for g in g.subjects) def test_chain1(self): result = set(select(g for g in Group for s in g.students if s.name.endswith('3'))) - self.assertEqual(result, 
set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_chain2(self): result = set(select(s for g in Group if g.dept.number == 44 for s in g.students if s.name.startswith('S'))) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_chain_m2m(self): result = set(select(g for g in Group for r in g.rooms if r.name == 'Room2')) - self.assertEqual(result, set([Group[1], Group[2]])) + self.assertEqual(result, {Group[1], Group[2]}) @raises_exception(TranslationError, 'All entities in a query must belong to the same database') def test_two_diagrams(self): select(g for g in Group for r in Room2 if r.name == 'Room2') @@ -183,10 +182,10 @@ def test_add_sub_mul_etc(self): self.assertEqual(result, [Student[2]]) def test_subscript(self): result = set(select(s for s in Student if s.name[1] == '2')) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_slice(self): result = set(select(s for s in Student if s.name[:1] == 'S')) - self.assertEqual(result, set([Student[3], Student[2], Student[1]])) + self.assertEqual(result, {Student[3], Student[2], Student[1]}) def test_attr_chain(self): s1 = Student[1] result = select(s for s in Student if s == s1)[:] @@ -207,9 +206,9 @@ def test_list_monad3(self): grade1 = Grade[Student[1], Course['Physics', 2]] grade2 = Grade[Student[1], Course['Math', 1]] result = set(select(g for g in Grade if g in [grade1, grade2])) - self.assertEqual(result, set([grade1, grade2])) + self.assertEqual(result, {grade1, grade2}) result = set(select(g for g in Grade if g not in [grade1, grade2])) - self.assertEqual(result, set([Grade[Student[2], Course['Economics', 1]]])) + self.assertEqual(result, {Grade[Student[2], Course['Economics', 1]]}) def test_tuple_monad1(self): n1 = 'S1' n2 = 'S2' @@ -235,7 +234,7 @@ def test_expr1(self): result = select(a for s in Student) def test_expr2(self): result = set(select(s.group for s in Student)) - 
self.assertEqual(result, set([Group[1], Group[2]])) + self.assertEqual(result, {Group[1], Group[2]}) def test_numeric_binop(self): i = 100 f = 2.0 @@ -246,19 +245,19 @@ def test_string_const_monad(self): self.assertEqual(result, []) def test_numeric_to_bool1(self): result = set(select(s for s in Student if s.name != 'John' or s.scholarship)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_numeric_to_bool2(self): result = set(select(s for s in Student if not s.scholarship)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_not_monad1(self): result = set(select(s for s in Student if not (s.scholarship > 0 and s.name != 'S1'))) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_not_monad2(self): result = set(select(s for s in Student if not not (s.scholarship > 0 and s.name != 'S1'))) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) def test_subquery_with_attr(self): result = set(select(s for s in Student if max(g.value for g in s.grades) == 'C')) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_query_reuse(self): q = select(s for s in Student if s.scholarship > 0) q.count() @@ -275,47 +274,47 @@ def test_order_by(self): self.assertEqual(result, [Student[1], Student[2], Student[3]]) def test_read_inside_query(self): result = set(select(s for s in Student if Group[1].dept.number == 44)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_crud_attr_chain(self): result = set(select(s for s in Student if Group[1].dept.number == s.group.dept.number)) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_composite_key1(self): 
result = set(select(t for t in Teacher if Grade[Student[1], Course['Physics', 2]] in t.grades)) - self.assertEqual(result, set([Teacher.get(name='T1')])) + self.assertEqual(result, {Teacher.get(name='T1')}) def test_composite_key2(self): result = set(select(s for s in Student if Course['Math', 1] in s.courses)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_composite_key3(self): result = set(select(s for s in Student if Course['Math', 1] not in s.courses)) - self.assertEqual(result, set([Student[2], Student[3]])) + self.assertEqual(result, {Student[2], Student[3]}) def test_composite_key4(self): result = set(select(s for s in Student if len(c for c in Course if c not in s.courses) == 2)) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_composite_key5(self): result = set(select(s for s in Student if not (c for c in Course if c not in s.courses))) self.assertEqual(result, set()) def test_composite_key6(self): result = set(select(c for c in Course if c not in (c2 for s in Student for c2 in s.courses))) - self.assertEqual(result, set([Course['Physics', 2]])) + self.assertEqual(result, {Course['Physics', 2]}) def test_composite_key7(self): result = set(select(c for s in Student for c in s.courses)) - self.assertEqual(result, set([Course['Math', 1], Course['Economics', 1]])) + self.assertEqual(result, {Course['Math', 1], Course['Economics', 1]}) def test_contains1(self): s1 = Student[1] result = set(select(g for g in Group if s1 in g.students)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_contains2(self): s1 = Student[1] result = set(select(g for g in Group if s1.name in g.students.name)) - self.assertEqual(result, set([Group[1]])) + self.assertEqual(result, {Group[1]}) def test_contains3(self): s1 = Student[1] result = set(select(g for g in Group if s1 not in g.students)) - self.assertEqual(result, 
set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_contains4(self): s1 = Student[1] result = set(select(g for g in Group if s1.name not in g.students.name)) - self.assertEqual(result, set([Group[2]])) + self.assertEqual(result, {Group[2]}) def test_buffer_monad1(self): try: select(s for s in Student if s.picture == buffer('abc')) except TypeError as e: self.assertTrue(not PY2 and str(e) == 'string argument without an encoding') @@ -324,32 +323,32 @@ def test_buffer_monad2(self): select(s for s in Student if s.picture == buffer('abc', 'ascii')) def test_database_monad(self): result = set(select(s for s in db.Student if db.Student[1] == s)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_duplicate_name(self): result = set(select(x for x in Student if x.group in (x for x in Group))) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_hint_join1(self): result = set(select(s for s in Student if JOIN(max(s.courses.credits) == 3))) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_hint_join2(self): result = set(select(c for c in Course if JOIN(len(c.students) == 1))) - self.assertEqual(result, set([Course['Math', 1], Course['Economics', 1]])) + self.assertEqual(result, {Course['Math', 1], Course['Economics', 1]}) def test_tuple_param(self): x = Student[1], Student[2] result = set(select(s for s in Student if s not in x)) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) @raises_exception(TypeError, "Expression `x` should not contain None values") def test_tuple_param_2(self): x = Student[1], None result = set(select(s for s in Student if s not in x)) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) @raises_exception(TypeError, "Function 'f' cannot be used this way: f(s)") def test_unknown_func(self): 
def f(x): return x select(s for s in Student if f(s)) def test_method_monad(self): result = set(select(s for s in Student if s not in Student.select(lambda s: s.scholarship > 0))) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) if __name__ == "__main__": diff --git a/pony/orm/tests/test_declarative_sqltranslator2.py b/pony/orm/tests/test_declarative_sqltranslator2.py index 016df1fc5..607862c90 100644 --- a/pony/orm/tests/test_declarative_sqltranslator2.py +++ b/pony/orm/tests/test_declarative_sqltranslator2.py @@ -108,10 +108,10 @@ def test_distinct4(self): self.assertEqual(q[:], [Department[2]]) def test_distinct5(self): result = set(select(s for s in Student)) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) def test_distinct6(self): result = set(select(s for s in Student).distinct()) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) def test_not_null1(self): q = select(g for g in Group if '123-45-67' not in g.students.tel and g.dept == Department[1]) not_null = "IS_NOT_NULL COLUMN student-1 tel" in (" ".join(str(i) for i in flatten(q._translator.conditions))) @@ -124,11 +124,11 @@ def test_not_null2(self): self.assertEqual(q[:], [Group[101]]) def test_chain_of_attrs_inside_for1(self): result = set(select(s for d in Department if d.number == 2 for s in d.groups.students)) - self.assertEqual(result, set([Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]}) def test_chain_of_attrs_inside_for2(self): pony.options.SIMPLE_ALIASES = False result = set(select(s for d in Department if d.number == 2 for s 
in d.groups.students)) - self.assertEqual(result, set([Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]}) pony.options.SIMPLE_ALIASES = True def test_non_entity_result1(self): result = select((s.name, s.group.number) for s in Student if s.name.startswith("J"))[:] @@ -146,7 +146,7 @@ def test_non_entity_result3a(self): self.assertEqual(sorted(result), [1988, 1989, 1990, 1991]) def test_non_entity_result4(self): result = set(select(s.name for s in Student if s.name.startswith('M'))) - self.assertEqual(result, set([u'Matthew Reed', u'Maria Ionescu'])) + self.assertEqual(result, {u'Matthew Reed', u'Maria Ionescu'}) def test_non_entity_result5(self): result = select((s.group, s.dob) for s in Student if s.group == Group[101])[:] self.assertEqual(sorted(result), [(Group[101], date(1989, 2, 5)), (Group[101], date(1990, 11, 26)), (Group[101], date(1991, 3, 20))]) @@ -156,7 +156,7 @@ def test_non_entity_result6(self): Student[2]), (Course[u'Web Design',1], Student[1]), (Course[u'Web Design',1], Student[2])])) def test_non_entity7(self): result = set(select(s for s in Student if (s.name, s.dob) not in (((s2.name, s2.dob) for s2 in Student if s.group.number == 101)))) - self.assertEqual(result, set([Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[4], Student[5], Student[6], Student[7]}) @raises_exception(IncomparableTypesError, "Incomparable types 'int' and 'Set of Student' in expression: g.number == g.students") def test_incompartible_types(self): select(g for g in Group if g.number == g.students) @@ -167,7 +167,7 @@ def test_external_param1(self): def test_external_param2(self): x = Student[1] result = set(select(s for s in Student if s.name != x.name)) - self.assertEqual(result, set([Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) 
@raises_exception(TypeError, "Use select(...) function or Group.select(...) method for iteration") def test_exception1(self): for g in Group: @@ -182,13 +182,13 @@ def test_entity_not_found(self): select(s for s in db.Student for g in db.FooBar) def test_keyargs1(self): result = set(select(s for s in Student if s.dob < date(year=1990, month=10, day=20))) - self.assertEqual(result, set([Student[3], Student[4], Student[6], Student[7]])) + self.assertEqual(result, {Student[3], Student[4], Student[6], Student[7]}) def test_query_as_string1(self): result = set(select('s for s in Student if 3 <= s.gpa < 4')) - self.assertEqual(result, set([Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]}) def test_query_as_string2(self): result = set(select('s for s in db.Student if 3 <= s.gpa < 4')) - self.assertEqual(result, set([Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]])) + self.assertEqual(result, {Student[1], Student[2], Student[4], Student[5], Student[6], Student[7]}) def test_str_subclasses(self): result = select(d for d in Department for g in d.groups for c in d.courses if g.number == 106 and c.name.startswith('T'))[:] self.assertEqual(result, [Department[3]]) @@ -199,7 +199,7 @@ class Unicode2(unicode): select(s for s in Student if len(u2) == 1) def test_bool(self): result = set(select(s for s in Student if s.phd == True)) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def test_bool2(self): result = list(select(s for s in Student if s.phd + 1 == True)) self.assertEqual(result, []) @@ -212,7 +212,7 @@ def test_bool4(self): def test_bool5(self): x = True result = set(select(s for s in Student if s.phd == True and (False or (True and x)))) - self.assertEqual(result, set([Student[1], Student[2]])) + self.assertEqual(result, {Student[1], Student[2]}) def 
test_bool6(self): x = False result = list(select(s for s in Student if s.phd == (False or (True and x)) and s.phd is True)) diff --git a/pony/orm/tests/test_declarative_strings.py b/pony/orm/tests/test_declarative_strings.py index bdeceacf0..2706e82c9 100644 --- a/pony/orm/tests/test_declarative_strings.py +++ b/pony/orm/tests/test_declarative_strings.py @@ -32,146 +32,146 @@ def tearDown(self): def test_nonzero(self): result = set(select(s for s in Student if s.foo)) - self.assertEqual(result, set([Student[1], Student[2], Student[3]])) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) def test_add(self): name = 'Jonny' result = set(select(s for s in Student if s.name + "ny" == name)) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_slice_1(self): result = set(select(s for s in Student if s.name[0:3] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_slice_2(self): result = set(select(s for s in Student if s.name[:3] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_slice_3(self): x = 3 result = set(select(s for s in Student if s.name[:x] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_slice_4(self): x = 3 result = set(select(s for s in Student if s.name[0:x] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_slice_5(self): result = set(select(s for s in Student if s.name[0:10] == "Jon")) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_slice_6(self): result = set(select(s for s in Student if s.name[0:] == "Jon")) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_slice_7(self): result = 
set(select(s for s in Student if s.name[:] == "Jon")) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_slice_8(self): result = set(select(s for s in Student if s.name[1:] == "on")) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_slice_9(self): x = 1 result = set(select(s for s in Student if s.name[x:] == "on")) - self.assertEqual(result, set([Student[1]])) + self.assertEqual(result, {Student[1]}) def test_slice_10(self): x = 0 result = set(select(s for s in Student if s.name[x:3] == "Jon")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_slice_11(self): x = 1 y = 3 result = set(select(s for s in Student if s.name[x:y] == "on")) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_slice_12(self): x = 10 y = 20 result = set(select(s for s in Student if s.name[x:y] == '')) - self.assertEqual(result, set([Student[1], Student[2], Student[3], Student[4], Student[5]])) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) def test_getitem_1(self): result = set(select(s for s in Student if s.name[1] == 'o')) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_getitem_2(self): x = 1 result = set(select(s for s in Student if s.name[x] == 'o')) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_getitem_3(self): result = set(select(s for s in Student if s.name[-1] == 'n')) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_getitem_4(self): x = -1 result = set(select(s for s in Student if s.name[x] == 'n')) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], 
Student[4]}) def test_contains_1(self): result = set(select(s for s in Student if 'o' in s.name)) - self.assertEqual(result, set([Student[1], Student[2], Student[4]])) + self.assertEqual(result, {Student[1], Student[2], Student[4]}) def test_contains_2(self): result = set(select(s for s in Student if 'on' in s.name)) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_contains_3(self): x = 'on' result = set(select(s for s in Student if x in s.name)) - self.assertEqual(result, set([Student[1], Student[4]])) + self.assertEqual(result, {Student[1], Student[4]}) def test_contains_4(self): x = 'on' result = set(select(s for s in Student if x not in s.name)) - self.assertEqual(result, set([Student[2], Student[3], Student[5]])) + self.assertEqual(result, {Student[2], Student[3], Student[5]}) def test_contains_5(self): result = set(select(s for s in Student if '%' in s.foo)) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_contains_6(self): x = '%' result = set(select(s for s in Student if x in s.foo)) - self.assertEqual(result, set([Student[2]])) + self.assertEqual(result, {Student[2]}) def test_contains_7(self): result = set(select(s for s in Student if '_' in s.foo)) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) def test_contains_8(self): x = '_' result = set(select(s for s in Student if x in s.foo)) - self.assertEqual(result, set([Student[3]])) + self.assertEqual(result, {Student[3]}) def test_contains_9(self): result = set(select(s for s in Student if s.foo in 'Abcdef')) - self.assertEqual(result, set([Student[1], Student[4], Student[5]])) + self.assertEqual(result, {Student[1], Student[4], Student[5]}) def test_contains_10(self): result = set(select(s for s in Student if s.bar in s.foo)) - self.assertEqual(result, set([Student[2], Student[4], Student[5]])) + self.assertEqual(result, {Student[2], Student[4], 
Student[5]}) def test_startswith_1(self): students = set(select(s for s in Student if s.name.startswith('J'))) - self.assertEqual(students, set([Student[1], Student[4]])) + self.assertEqual(students, {Student[1], Student[4]}) def test_startswith_2(self): students = set(select(s for s in Student if not s.name.startswith('J'))) - self.assertEqual(students, set([Student[2], Student[3], Student[5]])) + self.assertEqual(students, {Student[2], Student[3], Student[5]}) def test_startswith_3(self): students = set(select(s for s in Student if not not s.name.startswith('J'))) - self.assertEqual(students, set([Student[1], Student[4]])) + self.assertEqual(students, {Student[1], Student[4]}) def test_startswith_4(self): students = set(select(s for s in Student if not not not s.name.startswith('J'))) - self.assertEqual(students, set([Student[2], Student[3], Student[5]])) + self.assertEqual(students, {Student[2], Student[3], Student[5]}) def test_startswith_5(self): x = "Pe" @@ -180,7 +180,7 @@ def test_startswith_5(self): def test_endswith_1(self): students = set(select(s for s in Student if s.name.endswith('n'))) - self.assertEqual(students, set([Student[1], Student[4]])) + self.assertEqual(students, {Student[1], Student[4]}) def test_endswith_2(self): x = "te" diff --git a/pony/orm/tests/test_diagram.py b/pony/orm/tests/test_diagram.py index 235de72af..73f537f0e 100644 --- a/pony/orm/tests/test_diagram.py +++ b/pony/orm/tests/test_diagram.py @@ -98,8 +98,8 @@ class Entity2(db.Entity): attr2 = Set(Entity1) db.generate_mapping(create_tables=True) m2m_table = db.schema.tables['Entity1_Entity2'] - col_names = set([ col.name for col in m2m_table.column_list ]) - self.assertEqual(col_names, set(['entity1', 'entity2'])) + col_names = {col.name for col in m2m_table.column_list} + self.assertEqual(col_names, {'entity1', 'entity2'}) self.assertEqual(Entity1.attr1.get_m2m_columns(), ['entity1']) def test_diagram9(self): @@ -114,8 +114,8 @@ class Entity2(db.Entity): attr2 = Set(Entity1) 
db.generate_mapping(create_tables=True) m2m_table = db.schema.tables['Entity1_Entity2'] - col_names = set([ col.name for col in m2m_table.column_list ]) - self.assertEqual(col_names, set(['entity1_a', 'entity1_b', 'entity2'])) + col_names = {col.name for col in m2m_table.column_list} + self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2'}) def test_diagram10(self): db = Database('sqlite', ':memory:') diff --git a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index b0d9aff89..6505bd255 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -263,11 +263,11 @@ class Entity1(db.Entity): id = PrimaryKey(int, columns=['a', 'b']) db.generate_mapping(create_tables=True) - @raises_exception(TypeError, "Parameter 'columns' must be a list. Got: %r'" % set(['a'])) + @raises_exception(TypeError, "Parameter 'columns' must be a list. Got: %r'" % {'a'}) def test_columns6(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): - id = PrimaryKey(int, columns=set(['a'])) + id = PrimaryKey(int, columns={'a'}) db.generate_mapping(create_tables=True) @raises_exception(TypeError, "Parameter 'column' must be a string. 
Got: 4") diff --git a/pony/orm/tests/test_filter.py b/pony/orm/tests/test_filter.py index f6b6724a5..84d2b4ca9 100644 --- a/pony/orm/tests/test_filter.py +++ b/pony/orm/tests/test_filter.py @@ -14,17 +14,17 @@ def tearDown(self): def test_filter_1(self): q = select(s for s in Student) result = set(q.filter(scholarship=0)) - self.assertEqual(result, set([Student[101], Student[103]])) + self.assertEqual(result, {Student[101], Student[103]}) def test_filter_2(self): q = select(s for s in Student) q2 = q.filter(scholarship=500) result = set(q2.filter(group=Group['3132'])) - self.assertEqual(result, set([Student[104]])) + self.assertEqual(result, {Student[104]}) def test_filter_3(self): q = select(s for s in Student) q2 = q.filter(lambda s: s.scholarship > 500) result = set(q2.filter(lambda s: count(s.marks) > 0)) - self.assertEqual(result, set([Student[102]])) + self.assertEqual(result, {Student[102]}) def test_filter_4(self): q = select(s for s in Student) q2 = q.filter(lambda s: s.scholarship != 500) @@ -47,11 +47,11 @@ def test_filter_7(self): q = select(s for s in Student) q2 = q.filter(scholarship=0) result = set(q2.filter(lambda s: count(s.marks) > 1)) - self.assertEqual(result, set([Student[103], Student[101]])) + self.assertEqual(result, {Student[103], Student[101]}) def test_filter_8(self): q = select(s for s in Student) q2 = q.filter(lambda s: s.scholarship != 500) q3 = q2.order_by(lambda s: s.name) q4 = q3.order_by(None) result = set(q4.filter(lambda s: count(s.marks) > 1)) - self.assertEqual(result, set([Student[103], Student[101]])) + self.assertEqual(result, {Student[103], Student[101]}) diff --git a/pony/orm/tests/test_frames.py b/pony/orm/tests/test_frames.py index 67a56dbef..68ac3a601 100644 --- a/pony/orm/tests/test_frames.py +++ b/pony/orm/tests/test_frames.py @@ -23,25 +23,25 @@ class TestFrames(unittest.TestCase): def test_select(self): x = 20 result = select(p.id for p in Person if p.age > x)[:] - self.assertEqual(set(result), set([1, 3])) + 
self.assertEqual(set(result), {1, 3}) @db_session def test_select_str(self): x = 20 result = select('p.id for p in Person if p.age > x')[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_left_join(self): x = 20 result = left_join(p.id for p in Person if p.age > x)[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_left_join_str(self): x = 20 result = left_join('p.id for p in Person if p.age > x')[:] - self.assertEqual(set(result), set([1, 3])) + self.assertEqual(set(result), {1, 3}) @db_session def test_get(self): @@ -107,13 +107,13 @@ def test_entity_exists_str(self): def test_entity_select(self): x = 20 result = Person.select(lambda p: p.age > x)[:] - self.assertEqual(set(result), set([Person[1], Person[3]])) + self.assertEqual(set(result), {Person[1], Person[3]}) @db_session def test_entity_select_str(self): x = 20 result = Person.select('lambda p: p.age > x')[:] - self.assertEqual(set(result), set([Person[1], Person[3]])) + self.assertEqual(set(result), {Person[1], Person[3]}) @db_session def test_order_by(self): diff --git a/pony/orm/tests/test_prefetching.py b/pony/orm/tests/test_prefetching.py index 685e1435a..ff3fda31f 100644 --- a/pony/orm/tests/test_prefetching.py +++ b/pony/orm/tests/test_prefetching.py @@ -94,7 +94,7 @@ def test_9(self): def test_10(self): with db_session: s1 = Student.select().prefetch(Student.courses).first() - self.assertEqual(set(s1.courses.name), set(['Math', 'Physics'])) + self.assertEqual(set(s1.courses.name), {'Math', 'Physics'}) @raises_exception(DatabaseSessionIsOver, 'Cannot load attribute Student[1].biography: the database session is over') def test_11(self): diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index 151be61f9..22f3a89df 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -26,20 +26,20 @@ class TestRawSQL(unittest.TestCase): def 
test_1(self): # raw_sql result can be treated as a logical expression persons = select(p for p in Person if raw_sql('abs("p"."age") > 25'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_2(self): # raw_sql result can be used for comparison persons = select(p for p in Person if raw_sql('abs("p"."age")') > 25)[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_3(self): # raw_sql can accept $parameters x = 25 persons = select(p for p in Person if raw_sql('abs("p"."age") > $x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_4(self): @@ -47,7 +47,7 @@ def test_4(self): x = 1 s = 'p.id > $x' persons = select(p for p in Person if raw_sql(s))[:] - self.assertEqual(set(persons), set([Person[2], Person[3]])) + self.assertEqual(set(persons), {Person[2], Person[3]}) @db_session def test_5(self): @@ -55,14 +55,14 @@ def test_5(self): x = 1 cond = raw_sql('p.id > $x') persons = select(p for p in Person if cond)[:] - self.assertEqual(set(persons), set([Person[2], Person[3]])) + self.assertEqual(set(persons), {Person[2], Person[3]}) @db_session def test_6(self): # correct converter should be applied to raw_sql parameter type x = date(1990, 1, 1) persons = select(p for p in Person if raw_sql('p.dob < $x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_7(self): @@ -70,19 +70,19 @@ def test_7(self): x = 10 y = 15 persons = select(p for p in Person if raw_sql('p.age > $(x + y)'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_8(self): # raw_sql argument may be complex expression (2) persons = select(p for p in 
Person if raw_sql('p.dob < $date.today()'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2], Person[3]])) + self.assertEqual(set(persons), {Person[1], Person[2], Person[3]}) @db_session def test_9(self): # using raw_sql in the expression part of the generator names = select(raw_sql('UPPER(p.name)') for p in Person)[:] - self.assertEqual(set(names), set(['JOHN', 'MIKE', 'MARY'])) + self.assertEqual(set(names), {'JOHN', 'MIKE', 'MARY'}) @db_session def test_10(self): @@ -101,21 +101,21 @@ def test_12(self): # raw_sql can be used in lambdas x = 25 persons = Person.select(lambda p: p.age > raw_sql('$x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_13(self): # raw_sql in filter() x = 25 persons = select(p for p in Person).filter(lambda p: p.age > raw_sql('$x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_14(self): # raw_sql in filter() without using lambda x = 25 persons = Person.select().filter(raw_sql('p.age > $x'))[:] - self.assertEqual(set(persons), set([Person[1], Person[2]])) + self.assertEqual(set(persons), {Person[1], Person[2]}) @db_session def test_15(self): @@ -123,7 +123,7 @@ def test_15(self): x = '123' y = 'John' persons = Person.select(lambda p: raw_sql("UPPER(p.name) || $x") == raw_sql("UPPER($y || '123')"))[:] - self.assertEqual(set(persons), set([Person[1]])) + self.assertEqual(set(persons), {Person[1]}) @db_session def test_16(self): @@ -135,7 +135,7 @@ def test_16(self): y = 'j' q = q.filter(lambda p: p.dob > x and p.name.startswith(raw_sql('UPPER($y)'))) persons = q[:] - self.assertEqual(set(persons), set([Person[1]])) + self.assertEqual(set(persons), {Person[1]}) @db_session def test_17(self): diff --git a/pony/orm/tests/test_relations_m2m.py b/pony/orm/tests/test_relations_m2m.py index 376bf113e..4c80f885b 100644 --- 
a/pony/orm/tests/test_relations_m2m.py +++ b/pony/orm/tests/test_relations_m2m.py @@ -91,7 +91,7 @@ def test_5(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj3', 'Subj4']) - self.assertEqual(Group[101].subjects, set([Subject['Subj3'], Subject['Subj4']])) + self.assertEqual(Group[101].subjects, {Subject['Subj3'], Subject['Subj4']}) def test_6(self): db, Group, Subject = self.db, self.Group, self.Subject @@ -217,10 +217,10 @@ def test_13(self): g1.subjects.remove(s1) self.assertTrue(s1 not in group_setdata) self.assertEqual(group_setdata.added, None) - self.assertEqual(group_setdata.removed, set([ s1 ])) + self.assertEqual(group_setdata.removed, {s1}) self.assertTrue(g1 not in subj_setdata) self.assertEqual(subj_setdata.added, None) - self.assertEqual(subj_setdata.removed, set([ g1 ])) + self.assertEqual(subj_setdata.removed, {g1}) g1.subjects.add(s1) self.assertTrue(s1 in group_setdata) diff --git a/pony/orm/tests/test_relations_one2many.py b/pony/orm/tests/test_relations_one2many.py index 450e59b18..9caa55401 100644 --- a/pony/orm/tests/test_relations_one2many.py +++ b/pony/orm/tests/test_relations_one2many.py @@ -65,7 +65,7 @@ def test_4(self): s1, s2, s3, s4 = Student.select().order_by(Student.id) g1, g2 = Group[101], Group[102] g1.students = g2.students - self.assertEqual(set(g1.students), set([s3, s4])) + self.assertEqual(set(g1.students), {s3, s4}) self.assertEqual(s1._status_, 'marked_to_delete') self.assertEqual(s2._status_, 'marked_to_delete') @@ -223,7 +223,7 @@ def tearDown(self): def test_1(self): self.Student[1].group = None - self.assertEqual(set(self.Group[101].students), set([self.Student[2]])) + self.assertEqual(set(self.Group[101].students), {self.Student[2]}) def test_2(self): Student, Group = self.Student, self.Group @@ -246,7 +246,7 @@ def test_4(self): s1, s2, s3, s4 = Student.select().order_by(Student.id) g1, g2 = Group[101], Group[102] g1.students = 
g2.students - self.assertEqual(set(g1.students), set([s3, s4])) + self.assertEqual(set(g1.students), {s3, s4}) self.assertEqual(s1.group, None) self.assertEqual(s2.group, None) diff --git a/pony/orm/tests/test_relations_symmetric_m2m.py b/pony/orm/tests/test_relations_symmetric_m2m.py index c09d40742..b5762228d 100644 --- a/pony/orm/tests/test_relations_symmetric_m2m.py +++ b/pony/orm/tests/test_relations_symmetric_m2m.py @@ -32,12 +32,12 @@ def test1a(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) - self.assertEqual(set(p4.friends), set([p1])) + self.assertEqual(set(p4.friends), {p1}) def test1b(self): p1 = Person[1] p4 = Person[4] p1.friends.add(p4) - self.assertEqual(set(p1.friends), set([Person[2], Person[3], p4])) + self.assertEqual(set(p1.friends), {Person[2], Person[3], p4}) def test1c(self): p1 = Person[1] p4 = Person[4] @@ -49,12 +49,12 @@ def test2a(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) - self.assertEqual(set(p1.friends), set([Person[3]])) + self.assertEqual(set(p1.friends), {Person[3]}) def test2b(self): p1 = Person[1] p2 = Person[2] p1.friends.remove(p2) - self.assertEqual(set(Person[3].friends), set([p1])) + self.assertEqual(set(Person[3].friends), {p1}) def test2c(self): p1 = Person[1] p2 = Person[2] @@ -84,7 +84,7 @@ def test3b(self): p1 = Person[1] p2 = Person[2] p1_friends = set(p1.friends) - self.assertEqual(p1_friends, set([p2])) + self.assertEqual(p1_friends, {p2}) try: p2_friends = set(p2.friends) except UnrepeatableReadError as e: self.assertEqual(e.args[0], "Phantom object Person[1] disappeared from collection Person[2].friends") diff --git a/pony/orm/tests/test_relations_symmetric_one2one.py b/pony/orm/tests/test_relations_symmetric_one2one.py index 232b3190a..e07b920bd 100644 --- a/pony/orm/tests/test_relations_symmetric_one2one.py +++ b/pony/orm/tests/test_relations_symmetric_one2one.py @@ -63,7 +63,7 @@ def test3(self): self.assertEqual([3, None, 1, None, None], data) def test4(self): persons = set(select(p 
for p in Person if p.spouse.name in ('B', 'D'))) - self.assertEqual(persons, set([Person[1], Person[3]])) + self.assertEqual(persons, {Person[1], Person[3]}) @raises_exception(UnrepeatableReadError, 'Value of Person.spouse for Person[1] was updated outside of current transaction') def test5(self): db.execute('update person set spouse = 3 where id = 2') From cc57a17c180fd88231daf3855f872a225969e091 Mon Sep 17 00:00:00 2001 From: rmakarov94 Date: Mon, 12 Sep 2016 13:14:34 +0400 Subject: [PATCH 094/547] =?UTF-8?q?Prevent=20=E2=80=9CTypeError:=20Unicode?= =?UTF-8?q?-objects=20must=20be=20encoded=20before=20hashing=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2d03c2285..6ba4d0670 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -984,7 +984,7 @@ def _get_schema_dict(database): return result def _get_schema_json(database): schema_json = json.dumps(database._get_schema_dict(), default=basic_converter, sort_keys=True) - schema_hash = md5(schema_json).hexdigest() + schema_hash = md5(schema_json.encode('utf-8')).hexdigest() return schema_json, schema_hash @cut_traceback def to_json(database, data, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): From b033e09609794a4b20238a5fdaaf1687c53f7bde Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 7 Sep 2016 17:18:23 +0300 Subject: [PATCH 095/547] Optimization of nested CASE expressions --- pony/orm/sqlbuilding.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 6ba8af430..cd82ab2ef 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -471,6 +471,9 @@ def SUBSTR(builder, expr, start, len=None): if len is None: return 'substr(', builder(expr), ', ', builder(start), ')' return 'substr(', builder(expr), ', ', builder(start), ', ', 
builder(len), ')' def CASE(builder, expr, cases, default=None): + if expr is None and default is not None and default[0] == 'CASE' and default[1] is None: + cases2, default2 = default[2:] + return builder.CASE(None, tuple(cases) + tuple(cases2), default2) result = [ 'case' ] if expr is not None: result.append(' ') From 8c5b7d6fc42324cdbec15af060d9ff033b346e28 Mon Sep 17 00:00:00 2001 From: Vitalii Date: Mon, 22 Aug 2016 11:54:33 +0300 Subject: [PATCH 096/547] Fix fixtures --- pony/orm/tests/fixtures.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/tests/fixtures.py b/pony/orm/tests/fixtures.py index f8fef0a32..b1af03438 100644 --- a/pony/orm/tests/fixtures.py +++ b/pony/orm/tests/fixtures.py @@ -209,7 +209,7 @@ def __enter__(self): def __exit__(self, *exc_info): self.Test.db.provider.json1_available = self.json1_available - return super(SqliteNoJson1, self).__exit__() + return super(SqliteNoJson1, self).__exit__(*exc_info) @provider() From 29a07c9cb43349a4753cc41c8495a4e2d865d4c5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 14 Sep 2016 13:03:24 +0300 Subject: [PATCH 097/547] Join bug fixed --- pony/orm/sqltranslation.py | 14 +++- pony/orm/tests/test_inner_join_syntax.py | 90 ++++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) create mode 100644 pony/orm/tests/test_inner_join_syntax.py diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 22135e0e2..3f9cbdd69 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -869,6 +869,16 @@ def get_short_alias(subquery, name_path, entity_name): alias = '%s-%d' % (name, i) subquery.alias_counters[name] = i return alias + def join_table(subquery, parent_alias, alias, table_name, join_cond): + new_item = [alias, 'TABLE', table_name, join_cond] + from_ast = subquery.from_ast + for i in xrange(1, len(from_ast)): + if from_ast[i][0] == parent_alias: + for j in xrange(i+1, len(from_ast)): + if len(from_ast[j]) < 4: # item 
without join condition + from_ast.insert(j, new_item) + return + from_ast.append(new_item) class TableRef(object): def __init__(tableref, subquery, name, entity): @@ -945,7 +955,7 @@ def make_join(tableref, pk_only=False): m2m_alias = subquery.get_short_alias(None, 't') reverse_columns = attr.columns if attr.symmetric else attr.reverse.columns m2m_join_cond = join_tables(parent_alias, m2m_alias, left_pk_columns, reverse_columns) - subquery.from_ast.append([ m2m_alias, 'TABLE', m2m_table, m2m_join_cond ]) + subquery.join_table(parent_alias, m2m_alias, m2m_table, m2m_join_cond) if pk_only: tableref.alias = m2m_alias tableref.pk_columns = right_m2m_columns @@ -961,7 +971,7 @@ def make_join(tableref, pk_only=False): discr_criteria = entity._construct_discriminator_criteria_(alias) assert discr_criteria is not None join_cond.append(discr_criteria) - subquery.from_ast.append([ alias, 'TABLE', entity._table_, join_cond ]) + subquery.join_table(parent_alias, alias, entity._table_, join_cond) tableref.alias = alias tableref.pk_columns = pk_columns tableref.optimized = False diff --git a/pony/orm/tests/test_inner_join_syntax.py b/pony/orm/tests/test_inner_join_syntax.py new file mode 100644 index 000000000..20d4977db --- /dev/null +++ b/pony/orm/tests/test_inner_join_syntax.py @@ -0,0 +1,90 @@ +import unittest + +from pony.orm import * +from pony import orm + +import pony.orm.tests.fixtures + +class TestJoin(unittest.TestCase): + + exclude_fixtures = {'test': ['clear_tables']} + + @classmethod + def setUpClass(cls): + db = cls.db = Database('sqlite', ':memory:') + + class Genre(db.Entity): + name = orm.Optional(str) # TODO primary key + artists = orm.Set('Artist') + favorite = orm.Optional(bool) + index = orm.Optional(int) + + class Hobby(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') + + class Artist(db.Entity): + name = orm.Required(str) + age = orm.Optional(int) + hobbies = orm.Set(Hobby) + genres = orm.Set(Genre) + + 
db.generate_mapping(create_tables=True) + + with orm.db_session: + pop = Genre(name='pop') + rock = Genre(name='rock') + Artist(name='Sia', age=40, genres=[pop, rock]) + Artist(name='Lady GaGa', age=30, genres=[pop]) + + pony.options.INNER_JOIN_SYNTAX = True + + @db_session + def test_join_1(self): + result = select(g.id for g in self.db.Genre for a in g.artists if a.name.startswith('S'))[:] + self.assertEqual(self.db.last_sql, """SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre" + INNER JOIN "Artist" "a" + ON "t-1"."artist" = "a"."id" +WHERE "a"."name" LIKE 'S%'""") + + @db_session + def test_join_2(self): + result = select(g.id for g in self.db.Genre for a in self.db.Artist + if JOIN(a in g.artists) and a.name.startswith('S'))[:] + self.assertEqual(self.db.last_sql, """SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre", "Artist" "a" +WHERE "t-1"."artist" = "a"."id" + AND "a"."name" LIKE 'S%'""") + + + @db_session + def test_join_3(self): + result = select(g.id for g in self.db.Genre for x in self.db.Artist for a in self.db.Artist + if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:] + self.assertEqual(self.db.last_sql, '''SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre", "Artist" "x", "Artist" "a" +WHERE "t-1"."artist" = "a"."id" + AND "a"."name" LIKE 'S%' + AND "g"."id" = "x"."id"''') + + @db_session + def test_join_4(self): + result = select(g.id for g in self.db.Genre for a in self.db.Artist for x in self.db.Artist + if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:] + self.assertEqual(self.db.last_sql, '''SELECT DISTINCT "g"."id" +FROM "Genre" "g" + INNER JOIN "Artist_Genre" "t-1" + ON "g"."id" = "t-1"."genre", "Artist" "a", "Artist" "x" +WHERE "t-1"."artist" = "a"."id" + AND "a"."name" LIKE 'S%' + AND "g"."id" = "x"."id"''') + +if __name__ == 
'__main__': + unittest.main() From c37abe3abae99867c19d4bc4f100f57ac163288c Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Tue, 11 Oct 2016 20:14:24 +0300 Subject: [PATCH 098/547] Changing the license from AGPLv3 to Apache 2.0 --- LICENSE | 838 ++++++++++++------------------------------------------- setup.py | 10 +- 2 files changed, 181 insertions(+), 667 deletions(-) diff --git a/LICENSE b/LICENSE index 2def0e883..9a2a1c6bf 100644 --- a/LICENSE +++ b/LICENSE @@ -1,661 +1,179 @@ - GNU AFFERO GENERAL PUBLIC LICENSE - Version 3, 19 November 2007 - Copyright (C) 2007 Free Software Foundation, Inc. - Everyone is permitted to copy and distribute verbatim copies - of this license document, but changing it is not allowed. - - Preamble - - The GNU Affero General Public License is a free, copyleft license for -software and other kinds of works, specifically designed to ensure -cooperation with the community in the case of network server software. - - The licenses for most software and other practical works are designed -to take away your freedom to share and change the works. By contrast, -our General Public Licenses are intended to guarantee your freedom to -share and change all versions of a program--to make sure it remains free -software for all its users. - - When we speak of free software, we are referring to freedom, not -price. Our General Public Licenses are designed to make sure that you -have the freedom to distribute copies of free software (and charge for -them if you wish), that you receive source code or can get it if you -want it, that you can change the software or use pieces of it in new -free programs, and that you know you can do these things. - - Developers that use our General Public Licenses protect your rights -with two steps: (1) assert copyright on the software, and (2) offer -you this License which gives you legal permission to copy, distribute -and/or modify the software. 
- - A secondary benefit of defending all users' freedom is that -improvements made in alternate versions of the program, if they -receive widespread use, become available for other developers to -incorporate. Many developers of free software are heartened and -encouraged by the resulting cooperation. However, in the case of -software used on network servers, this result may fail to come about. -The GNU General Public License permits making a modified version and -letting the public access it on a server without ever releasing its -source code to the public. - - The GNU Affero General Public License is designed specifically to -ensure that, in such cases, the modified source code becomes available -to the community. It requires the operator of a network server to -provide the source code of the modified version running there to the -users of that server. Therefore, public use of a modified version, on -a publicly accessible server, gives the public access to the source -code of the modified version. - - An older license, called the Affero General Public License and -published by Affero, was designed to accomplish similar goals. This is -a different license, not a version of the Affero GPL, but Affero has -released a new version of the Affero GPL which permits relicensing under -this license. - - The precise terms and conditions for copying, distribution and -modification follow. - - TERMS AND CONDITIONS - - 0. Definitions. - - "This License" refers to version 3 of the GNU Affero General Public License. - - "Copyright" also means copyright-like laws that apply to other kinds of -works, such as semiconductor masks. - - "The Program" refers to any copyrightable work licensed under this -License. Each licensee is addressed as "you". "Licensees" and -"recipients" may be individuals or organizations. - - To "modify" a work means to copy from or adapt all or part of the work -in a fashion requiring copyright permission, other than the making of an -exact copy. 
The resulting work is called a "modified version" of the -earlier work or a work "based on" the earlier work. - - A "covered work" means either the unmodified Program or a work based -on the Program. - - To "propagate" a work means to do anything with it that, without -permission, would make you directly or secondarily liable for -infringement under applicable copyright law, except executing it on a -computer or modifying a private copy. Propagation includes copying, -distribution (with or without modification), making available to the -public, and in some countries other activities as well. - - To "convey" a work means any kind of propagation that enables other -parties to make or receive copies. Mere interaction with a user through -a computer network, with no transfer of a copy, is not conveying. - - An interactive user interface displays "Appropriate Legal Notices" -to the extent that it includes a convenient and prominently visible -feature that (1) displays an appropriate copyright notice, and (2) -tells the user that there is no warranty for the work (except to the -extent that warranties are provided), that licensees may convey the -work under this License, and how to view a copy of this License. If -the interface presents a list of user commands or options, such as a -menu, a prominent item in the list meets this criterion. - - 1. Source Code. - - The "source code" for a work means the preferred form of the work -for making modifications to it. "Object code" means any non-source -form of a work. - - A "Standard Interface" means an interface that either is an official -standard defined by a recognized standards body, or, in the case of -interfaces specified for a particular programming language, one that -is widely used among developers working in that language. 
- - The "System Libraries" of an executable work include anything, other -than the work as a whole, that (a) is included in the normal form of -packaging a Major Component, but which is not part of that Major -Component, and (b) serves only to enable use of the work with that -Major Component, or to implement a Standard Interface for which an -implementation is available to the public in source code form. A -"Major Component", in this context, means a major essential component -(kernel, window system, and so on) of the specific operating system -(if any) on which the executable work runs, or a compiler used to -produce the work, or an object code interpreter used to run it. - - The "Corresponding Source" for a work in object code form means all -the source code needed to generate, install, and (for an executable -work) run the object code and to modify the work, including scripts to -control those activities. However, it does not include the work's -System Libraries, or general-purpose tools or generally available free -programs which are used unmodified in performing those activities but -which are not part of the work. For example, Corresponding Source -includes interface definition files associated with source files for -the work, and the source code for shared libraries and dynamically -linked subprograms that the work is specifically designed to require, -such as by intimate data communication or control flow between those -subprograms and other parts of the work. - - The Corresponding Source need not include anything that users -can regenerate automatically from other parts of the Corresponding -Source. - - The Corresponding Source for a work in source code form is that -same work. - - 2. Basic Permissions. - - All rights granted under this License are granted for the term of -copyright on the Program, and are irrevocable provided the stated -conditions are met. This License explicitly affirms your unlimited -permission to run the unmodified Program. 
The output from running a -covered work is covered by this License only if the output, given its -content, constitutes a covered work. This License acknowledges your -rights of fair use or other equivalent, as provided by copyright law. - - You may make, run and propagate covered works that you do not -convey, without conditions so long as your license otherwise remains -in force. You may convey covered works to others for the sole purpose -of having them make modifications exclusively for you, or provide you -with facilities for running those works, provided that you comply with -the terms of this License in conveying all material for which you do -not control copyright. Those thus making or running the covered works -for you must do so exclusively on your behalf, under your direction -and control, on terms that prohibit them from making any copies of -your copyrighted material outside their relationship with you. - - Conveying under any other circumstances is permitted solely under -the conditions stated below. Sublicensing is not allowed; section 10 -makes it unnecessary. - - 3. Protecting Users' Legal Rights From Anti-Circumvention Law. - - No covered work shall be deemed part of an effective technological -measure under any applicable law fulfilling obligations under article -11 of the WIPO copyright treaty adopted on 20 December 1996, or -similar laws prohibiting or restricting circumvention of such -measures. - - When you convey a covered work, you waive any legal power to forbid -circumvention of technological measures to the extent such circumvention -is effected by exercising rights under this License with respect to -the covered work, and you disclaim any intention to limit operation or -modification of the work as a means of enforcing, against the work's -users, your or third parties' legal rights to forbid circumvention of -technological measures. - - 4. Conveying Verbatim Copies. 
- - You may convey verbatim copies of the Program's source code as you -receive it, in any medium, provided that you conspicuously and -appropriately publish on each copy an appropriate copyright notice; -keep intact all notices stating that this License and any -non-permissive terms added in accord with section 7 apply to the code; -keep intact all notices of the absence of any warranty; and give all -recipients a copy of this License along with the Program. - - You may charge any price or no price for each copy that you convey, -and you may offer support or warranty protection for a fee. - - 5. Conveying Modified Source Versions. - - You may convey a work based on the Program, or the modifications to -produce it from the Program, in the form of source code under the -terms of section 4, provided that you also meet all of these conditions: - - a) The work must carry prominent notices stating that you modified - it, and giving a relevant date. - - b) The work must carry prominent notices stating that it is - released under this License and any conditions added under section - 7. This requirement modifies the requirement in section 4 to - "keep intact all notices". - - c) You must license the entire work, as a whole, under this - License to anyone who comes into possession of a copy. This - License will therefore apply, along with any applicable section 7 - additional terms, to the whole of the work, and all its parts, - regardless of how they are packaged. This License gives no - permission to license the work in any other way, but it does not - invalidate such permission if you have separately received it. - - d) If the work has interactive user interfaces, each must display - Appropriate Legal Notices; however, if the Program has interactive - interfaces that do not display Appropriate Legal Notices, your - work need not make them do so. 
- - A compilation of a covered work with other separate and independent -works, which are not by their nature extensions of the covered work, -and which are not combined with it such as to form a larger program, -in or on a volume of a storage or distribution medium, is called an -"aggregate" if the compilation and its resulting copyright are not -used to limit the access or legal rights of the compilation's users -beyond what the individual works permit. Inclusion of a covered work -in an aggregate does not cause this License to apply to the other -parts of the aggregate. - - 6. Conveying Non-Source Forms. - - You may convey a covered work in object code form under the terms -of sections 4 and 5, provided that you also convey the -machine-readable Corresponding Source under the terms of this License, -in one of these ways: - - a) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by the - Corresponding Source fixed on a durable physical medium - customarily used for software interchange. - - b) Convey the object code in, or embodied in, a physical product - (including a physical distribution medium), accompanied by a - written offer, valid for at least three years and valid for as - long as you offer spare parts or customer support for that product - model, to give anyone who possesses the object code either (1) a - copy of the Corresponding Source for all the software in the - product that is covered by this License, on a durable physical - medium customarily used for software interchange, for a price no - more than your reasonable cost of physically performing this - conveying of source, or (2) access to copy the - Corresponding Source from a network server at no charge. - - c) Convey individual copies of the object code with a copy of the - written offer to provide the Corresponding Source. 
This - alternative is allowed only occasionally and noncommercially, and - only if you received the object code with such an offer, in accord - with subsection 6b. - - d) Convey the object code by offering access from a designated - place (gratis or for a charge), and offer equivalent access to the - Corresponding Source in the same way through the same place at no - further charge. You need not require recipients to copy the - Corresponding Source along with the object code. If the place to - copy the object code is a network server, the Corresponding Source - may be on a different server (operated by you or a third party) - that supports equivalent copying facilities, provided you maintain - clear directions next to the object code saying where to find the - Corresponding Source. Regardless of what server hosts the - Corresponding Source, you remain obligated to ensure that it is - available for as long as needed to satisfy these requirements. - - e) Convey the object code using peer-to-peer transmission, provided - you inform other peers where the object code and Corresponding - Source of the work are being offered to the general public at no - charge under subsection 6d. - - A separable portion of the object code, whose source code is excluded -from the Corresponding Source as a System Library, need not be -included in conveying the object code work. - - A "User Product" is either (1) a "consumer product", which means any -tangible personal property which is normally used for personal, family, -or household purposes, or (2) anything designed or sold for incorporation -into a dwelling. In determining whether a product is a consumer product, -doubtful cases shall be resolved in favor of coverage. 
For a particular -product received by a particular user, "normally used" refers to a -typical or common use of that class of product, regardless of the status -of the particular user or of the way in which the particular user -actually uses, or expects or is expected to use, the product. A product -is a consumer product regardless of whether the product has substantial -commercial, industrial or non-consumer uses, unless such uses represent -the only significant mode of use of the product. - - "Installation Information" for a User Product means any methods, -procedures, authorization keys, or other information required to install -and execute modified versions of a covered work in that User Product from -a modified version of its Corresponding Source. The information must -suffice to ensure that the continued functioning of the modified object -code is in no case prevented or interfered with solely because -modification has been made. - - If you convey an object code work under this section in, or with, or -specifically for use in, a User Product, and the conveying occurs as -part of a transaction in which the right of possession and use of the -User Product is transferred to the recipient in perpetuity or for a -fixed term (regardless of how the transaction is characterized), the -Corresponding Source conveyed under this section must be accompanied -by the Installation Information. But this requirement does not apply -if neither you nor any third party retains the ability to install -modified object code on the User Product (for example, the work has -been installed in ROM). - - The requirement to provide Installation Information does not include a -requirement to continue to provide support service, warranty, or updates -for a work that has been modified or installed by the recipient, or for -the User Product in which it has been modified or installed. 
Access to a -network may be denied when the modification itself materially and -adversely affects the operation of the network or violates the rules and -protocols for communication across the network. - - Corresponding Source conveyed, and Installation Information provided, -in accord with this section must be in a format that is publicly -documented (and with an implementation available to the public in -source code form), and must require no special password or key for -unpacking, reading or copying. - - 7. Additional Terms. - - "Additional permissions" are terms that supplement the terms of this -License by making exceptions from one or more of its conditions. -Additional permissions that are applicable to the entire Program shall -be treated as though they were included in this License, to the extent -that they are valid under applicable law. If additional permissions -apply only to part of the Program, that part may be used separately -under those permissions, but the entire Program remains governed by -this License without regard to the additional permissions. - - When you convey a copy of a covered work, you may at your option -remove any additional permissions from that copy, or from any part of -it. (Additional permissions may be written to require their own -removal in certain cases when you modify the work.) You may place -additional permissions on material, added by you to a covered work, -for which you have or can give appropriate copyright permission. 
- - Notwithstanding any other provision of this License, for material you -add to a covered work, you may (if authorized by the copyright holders of -that material) supplement the terms of this License with terms: - - a) Disclaiming warranty or limiting liability differently from the - terms of sections 15 and 16 of this License; or - - b) Requiring preservation of specified reasonable legal notices or - author attributions in that material or in the Appropriate Legal - Notices displayed by works containing it; or - - c) Prohibiting misrepresentation of the origin of that material, or - requiring that modified versions of such material be marked in - reasonable ways as different from the original version; or - - d) Limiting the use for publicity purposes of names of licensors or - authors of the material; or - - e) Declining to grant rights under trademark law for use of some - trade names, trademarks, or service marks; or - - f) Requiring indemnification of licensors and authors of that - material by anyone who conveys the material (or modified versions of - it) with contractual assumptions of liability to the recipient, for - any liability that these contractual assumptions directly impose on - those licensors and authors. - - All other non-permissive additional terms are considered "further -restrictions" within the meaning of section 10. If the Program as you -received it, or any part of it, contains a notice stating that it is -governed by this License along with a term that is a further -restriction, you may remove that term. If a license document contains -a further restriction but permits relicensing or conveying under this -License, you may add to a covered work material governed by the terms -of that license document, provided that the further restriction does -not survive such relicensing or conveying. 
- - If you add terms to a covered work in accord with this section, you -must place, in the relevant source files, a statement of the -additional terms that apply to those files, or a notice indicating -where to find the applicable terms. - - Additional terms, permissive or non-permissive, may be stated in the -form of a separately written license, or stated as exceptions; -the above requirements apply either way. - - 8. Termination. - - You may not propagate or modify a covered work except as expressly -provided under this License. Any attempt otherwise to propagate or -modify it is void, and will automatically terminate your rights under -this License (including any patent licenses granted under the third -paragraph of section 11). - - However, if you cease all violation of this License, then your -license from a particular copyright holder is reinstated (a) -provisionally, unless and until the copyright holder explicitly and -finally terminates your license, and (b) permanently, if the copyright -holder fails to notify you of the violation by some reasonable means -prior to 60 days after the cessation. - - Moreover, your license from a particular copyright holder is -reinstated permanently if the copyright holder notifies you of the -violation by some reasonable means, this is the first time you have -received notice of violation of this License (for any work) from that -copyright holder, and you cure the violation prior to 30 days after -your receipt of the notice. - - Termination of your rights under this section does not terminate the -licenses of parties who have received copies or rights from you under -this License. If your rights have been terminated and not permanently -reinstated, you do not qualify to receive new licenses for the same -material under section 10. - - 9. Acceptance Not Required for Having Copies. - - You are not required to accept this License in order to receive or -run a copy of the Program. 
Ancillary propagation of a covered work -occurring solely as a consequence of using peer-to-peer transmission -to receive a copy likewise does not require acceptance. However, -nothing other than this License grants you permission to propagate or -modify any covered work. These actions infringe copyright if you do -not accept this License. Therefore, by modifying or propagating a -covered work, you indicate your acceptance of this License to do so. - - 10. Automatic Licensing of Downstream Recipients. - - Each time you convey a covered work, the recipient automatically -receives a license from the original licensors, to run, modify and -propagate that work, subject to this License. You are not responsible -for enforcing compliance by third parties with this License. - - An "entity transaction" is a transaction transferring control of an -organization, or substantially all assets of one, or subdividing an -organization, or merging organizations. If propagation of a covered -work results from an entity transaction, each party to that -transaction who receives a copy of the work also receives whatever -licenses to the work the party's predecessor in interest had or could -give under the previous paragraph, plus a right to possession of the -Corresponding Source of the work from the predecessor in interest, if -the predecessor has it or can get it with reasonable efforts. - - You may not impose any further restrictions on the exercise of the -rights granted or affirmed under this License. For example, you may -not impose a license fee, royalty, or other charge for exercise of -rights granted under this License, and you may not initiate litigation -(including a cross-claim or counterclaim in a lawsuit) alleging that -any patent claim is infringed by making, using, selling, offering for -sale, or importing the Program or any portion of it. - - 11. Patents. 
- - A "contributor" is a copyright holder who authorizes use under this -License of the Program or a work on which the Program is based. The -work thus licensed is called the contributor's "contributor version". - - A contributor's "essential patent claims" are all patent claims -owned or controlled by the contributor, whether already acquired or -hereafter acquired, that would be infringed by some manner, permitted -by this License, of making, using, or selling its contributor version, -but do not include claims that would be infringed only as a -consequence of further modification of the contributor version. For -purposes of this definition, "control" includes the right to grant -patent sublicenses in a manner consistent with the requirements of -this License. - - Each contributor grants you a non-exclusive, worldwide, royalty-free -patent license under the contributor's essential patent claims, to -make, use, sell, offer for sale, import and otherwise run, modify and -propagate the contents of its contributor version. - - In the following three paragraphs, a "patent license" is any express -agreement or commitment, however denominated, not to enforce a patent -(such as an express permission to practice a patent or covenant not to -sue for patent infringement). To "grant" such a patent license to a -party means to make such an agreement or commitment not to enforce a -patent against the party. 
- - If you convey a covered work, knowingly relying on a patent license, -and the Corresponding Source of the work is not available for anyone -to copy, free of charge and under the terms of this License, through a -publicly available network server or other readily accessible means, -then you must either (1) cause the Corresponding Source to be so -available, or (2) arrange to deprive yourself of the benefit of the -patent license for this particular work, or (3) arrange, in a manner -consistent with the requirements of this License, to extend the patent -license to downstream recipients. "Knowingly relying" means you have -actual knowledge that, but for the patent license, your conveying the -covered work in a country, or your recipient's use of the covered work -in a country, would infringe one or more identifiable patents in that -country that you have reason to believe are valid. - - If, pursuant to or in connection with a single transaction or -arrangement, you convey, or propagate by procuring conveyance of, a -covered work, and grant a patent license to some of the parties -receiving the covered work authorizing them to use, propagate, modify -or convey a specific copy of the covered work, then the patent license -you grant is automatically extended to all recipients of the covered -work and works based on it. - - A patent license is "discriminatory" if it does not include within -the scope of its coverage, prohibits the exercise of, or is -conditioned on the non-exercise of one or more of the rights that are -specifically granted under this License. 
You may not convey a covered -work if you are a party to an arrangement with a third party that is -in the business of distributing software, under which you make payment -to the third party based on the extent of your activity of conveying -the work, and under which the third party grants, to any of the -parties who would receive the covered work from you, a discriminatory -patent license (a) in connection with copies of the covered work -conveyed by you (or copies made from those copies), or (b) primarily -for and in connection with specific products or compilations that -contain the covered work, unless you entered into that arrangement, -or that patent license was granted, prior to 28 March 2007. - - Nothing in this License shall be construed as excluding or limiting -any implied license or other defenses to infringement that may -otherwise be available to you under applicable patent law. - - 12. No Surrender of Others' Freedom. - - If conditions are imposed on you (whether by court order, agreement or -otherwise) that contradict the conditions of this License, they do not -excuse you from the conditions of this License. If you cannot convey a -covered work so as to satisfy simultaneously your obligations under this -License and any other pertinent obligations, then as a consequence you may -not convey it at all. For example, if you agree to terms that obligate you -to collect a royalty for further conveying from those to whom you convey -the Program, the only way you could satisfy both those terms and this -License would be to refrain entirely from conveying the Program. - - 13. Remote Network Interaction; Use with the GNU General Public License. 
- - Notwithstanding any other provision of this License, if you modify the -Program, your modified version must prominently offer all users -interacting with it remotely through a computer network (if your version -supports such interaction) an opportunity to receive the Corresponding -Source of your version by providing access to the Corresponding Source -from a network server at no charge, through some standard or customary -means of facilitating copying of software. This Corresponding Source -shall include the Corresponding Source for any work covered by version 3 -of the GNU General Public License that is incorporated pursuant to the -following paragraph. - - Notwithstanding any other provision of this License, you have -permission to link or combine any covered work with a work licensed -under version 3 of the GNU General Public License into a single -combined work, and to convey the resulting work. The terms of this -License will continue to apply to the part which is the covered work, -but the work with which it is combined will remain governed by version -3 of the GNU General Public License. - - 14. Revised Versions of this License. - - The Free Software Foundation may publish revised and/or new versions of -the GNU Affero General Public License from time to time. Such new versions -will be similar in spirit to the present version, but may differ in detail to -address new problems or concerns. - - Each version is given a distinguishing version number. If the -Program specifies that a certain numbered version of the GNU Affero General -Public License "or any later version" applies to it, you have the -option of following the terms and conditions either of that numbered -version or of any later version published by the Free Software -Foundation. If the Program does not specify a version number of the -GNU Affero General Public License, you may choose any version ever published -by the Free Software Foundation. 
- - If the Program specifies that a proxy can decide which future -versions of the GNU Affero General Public License can be used, that proxy's -public statement of acceptance of a version permanently authorizes you -to choose that version for the Program. - - Later license versions may give you additional or different -permissions. However, no additional obligations are imposed on any -author or copyright holder as a result of your choosing to follow a -later version. - - 15. Disclaimer of Warranty. - - THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY -APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT -HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY -OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, -THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR -PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM -IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF -ALL NECESSARY SERVICING, REPAIR OR CORRECTION. - - 16. Limitation of Liability. - - IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING -WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS -THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY -GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE -USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF -DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD -PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), -EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF -SUCH DAMAGES. - - 17. Interpretation of Sections 15 and 16. 
- - If the disclaimer of warranty and limitation of liability provided -above cannot be given local legal effect according to their terms, -reviewing courts shall apply local law that most closely approximates -an absolute waiver of all civil liability in connection with the -Program, unless a warranty or assumption of liability accompanies a -copy of the Program in return for a fee. - - END OF TERMS AND CONDITIONS - - How to Apply These Terms to Your New Programs - - If you develop a new program, and you want it to be of the greatest -possible use to the public, the best way to achieve this is to make it -free software which everyone can redistribute and change under these terms. - - To do so, attach the following notices to the program. It is safest -to attach them to the start of each source file to most effectively -state the exclusion of warranty; and each file should have at least -the "copyright" line and a pointer to where the full notice is found. - - - Copyright (C) - - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. - - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . - -Also add information on how to contact you by electronic and paper mail. - - If your software can interact with users remotely through a computer -network, you should also make sure that it provides a way for users to -get its source. For example, if your program is a web application, its -interface could display a "Source" link that leads users to an archive -of the code. 
There are many ways you could offer source, and different -solutions will be better for different programs; see section 13 for the -specific requirements. - - You should also get your employer (if you work as a programmer) or school, -if any, to sign a "copyright disclaimer" for the program, if necessary. -For more information on this, and how to apply and follow the GNU AGPL, see -. \ No newline at end of file + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + Copyright 2016 Alexander Kozlovsky, Alexey Malashkevich diff --git a/setup.py b/setup.py index 3ab1a3104..e3cc03145 100644 --- a/setup.py +++ b/setup.py @@ -54,11 +54,7 @@ classifiers = [ 'Development Status :: 4 - Beta', 'Intended Audience :: Developers', - 'License :: Free for non-commercial use', - 'License :: OSI Approved :: GNU Affero General Public License v3', - 'License :: Other/Proprietary License', - 'License :: Free For Educational Use', - 'License :: Free for non-commercial use', + 'License :: OSI Approved :: Apache Software License', 'Operating System :: OS Independent', 'Programming Language :: Python', 'Programming Language :: Python :: 2', @@ -74,7 +70,7 @@ author = "Alexander Kozlovsky, Alexey Malashkevich" author_email = "team@ponyorm.com" url = "https://ponyorm.com" -lic = "AGPL, Commercial, Free for educational and non-commercial use" +licence = "Apache License Version 2.0" packages = [ "pony", @@ -107,7 +103,7 @@ author=author, author_email=author_email, url=url, - license=lic, + license=licence, packages=packages, download_url=download_url ) From bec5ff74df688daa4bbb2af26c44ce389903e32b Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Tue, 11 Oct 2016 21:47:48 +0300 Subject: [PATCH 099/547] Release 0.7 change log --- CHANGELOG.md | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b1b434124..d32da2a7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,22 @@ +# Pony ORM Release 0.7 (2016-10-11) + +Starting with this release Pony ORM is release under the Apache License, Version 2.0. 
+ +## New features + +* Added getattr() support in queries: https://docs.ponyorm.com/api_reference.html#getattr + +## Backward incompatible changes + +* #159: exceptions happened during flush() should not be wrapped with CommitException + +Before this release an exception that happened in a hook(https://docs.ponyorm.com/api_reference.html#entity-hooks), could be raised in two ways - either wrapped into the CommitException or without wrapping. It depended if the exception happened during the execution of flush() or commit() function on the db_session exit. Now the exception happened inside the hook never will be wrapped into the CommitException. + +## Bugfixes + +* #190: Timedelta is not supported when using pymysql + + # Pony ORM Release 0.6.6 (2016-08-22) ## New features From 485dcb8363f4bd69d06b91a03af4c2abcfa6da8c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 11 Oct 2016 20:44:25 +0300 Subject: [PATCH 100/547] Update Pony version: 0.6.7-dev -> 0.7 --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 384388bdb..3135937fb 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.6.7-dev' +__version__ = '0.7' uid = str(random.randint(1, 1000000)) From 23d4c849c8f9594d9dbfc4b73fb815ef42dadb0d Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Tue, 25 Oct 2016 13:30:05 +0300 Subject: [PATCH 101/547] Update README.md Fixes #206 --- README.md | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e7fd3ee59..eff8c8d06 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,50 @@ Pony Object-Relational Mapper -================================== +============================= -Pony is an object-relational mapper. 
The most interesting feature of Pony is its ability to write queries to the database using generator expressions. Pony works with entities which are mapped to a SQL database. Using generator syntax for writing queries allows the user to formulate very eloquent queries. It increases the level of abstraction and allows a programmer to concentrate on the business logic of the application. For this purpose Pony analyzes the abstract syntax tree of a generator and translates it to its SQL equivalent. +Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. Pony analyzes the abstract syntax tree of the generator expression and translates it to into SQL query. -Following is an example of a query in Pony: +Here is the example of a query in Pony: select(p for p in Product if p.name.startswith('A') and p.cost <= 1000) Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. -Pony ORM also include the ER Diagram Editor which is a great tool for prototyping. You can create your ER diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. +By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps saving resources. 
Pony achieves the easiness of use through the following: + +* Compact entity definitions +* Concise query language +* Ability to work with Pony interactively in Python interpreter +* Comprehensive error messages, showing the exact part where error happened in the query +* Displaying the generated SQL in readable format with indentation + +All this helps the developer to focus on implementing the business logic of an application, instead of struggling with a mapper trying to understand how to get the data from the database. + +See the example [here](https://github.com/ponyorm/pony/blob/orm/pony/orm/examples/estore.py) + + +Online tool for database design +------------------------------- + +Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool for prototyping. You can create your database diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. -The package pony.orm.examples contains several examples. Documenation is available at [https://docs.ponyorm.com](https://docs.ponyorm.com) The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc), it is released under Apache 2.0 license. Please create new documentation related issues [https://github.com/ponyorm/pony-doc/issues](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. -We are looking forward to your comments and suggestions at our mailing list [http://ponyorm-list.ponyorm.com](http://ponyorm-list.ponyorm.com) License ------------ +------- + +Pony ORM is released under the Apache 2.0 license. + -Pony ORM is released under multiple licenses, check [ponyorm.com](https://ponyorm.com/license-and-pricing.html) for more information. +PonyORM community +----------------- -Copyright (c) 2016 Pony ORM, LLC. All rights reserved. 
-team (at) ponyorm.com +Please post your questions on [Stack Overflow](http://stackoverflow.com/questions/tagged/ponyorm). +Meet the PonyORM team, chat with the community members, and get your questions answered on our community [Telegram group](https://telegram.me/ponyorm). +Join our newsletter at [ponyorm.com](https://ponyorm.com). +Reach us on [Twitter](https://twitter.com/ponyorm). -Please send your questions, comments and suggestions to our mailing list [http://ponyorm-list.ponyorm.com](http://ponyorm-list.ponyorm.com) +Copyright (c) 2016 Pony ORM, LLC. All rights reserved. team (at) ponyorm.com From eff28b900942404c3e8a2cbae511ce95e8bbc6ba Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Tue, 25 Oct 2016 13:30:05 +0300 Subject: [PATCH 102/547] Update README.md Fixes #206 --- README.md | 42 +++++++++++++++++++++++++++++++----------- 1 file changed, 31 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index e7fd3ee59..eff8c8d06 100644 --- a/README.md +++ b/README.md @@ -1,30 +1,50 @@ Pony Object-Relational Mapper -================================== +============================= -Pony is an object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using generator expressions. Pony works with entities which are mapped to a SQL database. Using generator syntax for writing queries allows the user to formulate very eloquent queries. It increases the level of abstraction and allows a programmer to concentrate on the business logic of the application. For this purpose Pony analyzes the abstract syntax tree of a generator and translates it to its SQL equivalent. +Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. Pony analyzes the abstract syntax tree of the generator expression and translates it to into SQL query. 
-Following is an example of a query in Pony: +Here is the example of a query in Pony: select(p for p in Product if p.name.startswith('A') and p.cost <= 1000) Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. -Pony ORM also include the ER Diagram Editor which is a great tool for prototyping. You can create your ER diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. +By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps saving resources. Pony achieves the easiness of use through the following: + +* Compact entity definitions +* Concise query language +* Ability to work with Pony interactively in Python interpreter +* Comprehensive error messages, showing the exact part where error happened in the query +* Displaying the generated SQL in readable format with indentation + +All this helps the developer to focus on implementing the business logic of an application, instead of struggling with a mapper trying to understand how to get the data from the database. + +See the example [here](https://github.com/ponyorm/pony/blob/orm/pony/orm/examples/estore.py) + + +Online tool for database design +------------------------------- + +Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool for prototyping. You can create your database diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. -The package pony.orm.examples contains several examples. 
Documenation is available at [https://docs.ponyorm.com](https://docs.ponyorm.com) The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc), it is released under Apache 2.0 license. Please create new documentation related issues [https://github.com/ponyorm/pony-doc/issues](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. -We are looking forward to your comments and suggestions at our mailing list [http://ponyorm-list.ponyorm.com](http://ponyorm-list.ponyorm.com) License ------------ +------- + +Pony ORM is released under the Apache 2.0 license. + -Pony ORM is released under multiple licenses, check [ponyorm.com](https://ponyorm.com/license-and-pricing.html) for more information. +PonyORM community +----------------- -Copyright (c) 2016 Pony ORM, LLC. All rights reserved. -team (at) ponyorm.com +Please post your questions on [Stack Overflow](http://stackoverflow.com/questions/tagged/ponyorm). +Meet the PonyORM team, chat with the community members, and get your questions answered on our community [Telegram group](https://telegram.me/ponyorm). +Join our newsletter at [ponyorm.com](https://ponyorm.com). +Reach us on [Twitter](https://twitter.com/ponyorm). -Please send your questions, comments and suggestions to our mailing list [http://ponyorm-list.ponyorm.com](http://ponyorm-list.ponyorm.com) +Copyright (c) 2016 Pony ORM, LLC. All rights reserved. 
team (at) ponyorm.com From 41c498a844339de9bd3dc3de071a917122672187 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 25 Oct 2016 19:15:33 +0300 Subject: [PATCH 103/547] Change Pony version: 0.7 -> 0.7.1-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 3135937fb..77376cb2e 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7' +__version__ = '0.7.1-dev' uid = str(random.randint(1, 1000000)) From 54275da79a7165a9b867fec165b9457ff85ecac8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 25 Oct 2016 17:04:01 +0300 Subject: [PATCH 104/547] Fixes #203: subtranslator should use argnames from parent translator --- pony/orm/sqltranslation.py | 11 +++++++---- pony/orm/tests/test_declarative_sqltranslator.py | 12 ++++++++++++ 2 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 3f9cbdd69..be33f9880 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -674,10 +674,13 @@ def postTuple(translator, node): return translator.ListMonad(translator, [ item.monad for item in node.nodes ]) def postName(translator, node): name = node.name - argnames = translator.argnames - if translator.argnames and name in translator.argnames: - i = translator.argnames.index(name) - return translator.expr_monads[i] + t = translator + while t is not None: + argnames = t.argnames + if argnames is not None and name in argnames: + i = argnames.index(name) + return t.expr_monads[i] + t = t.parent tableref = translator.subquery.get_tableref(name) if tableref is not None: return translator.ObjectIterMonad(translator, tableref, tableref.entity) diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index c28832374..33aff08de 100644 --- 
a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -349,6 +349,18 @@ def f(x): return x def test_method_monad(self): result = set(select(s for s in Student if s not in Student.select(lambda s: s.scholarship > 0))) self.assertEqual(result, {Student[1]}) + def test_lambda_1(self): + q = select(s for s in Student) + q = q.filter(lambda s: s.name == 'S1') + self.assertEqual(list(q), [Student[1]]) + def test_lambda_2(self): + q = select(s for s in Student) + q = q.filter(lambda stud: stud.name == 'S1') + self.assertEqual(list(q), [Student[1]]) + def test_lambda_3(self): + q = select(s for s in Student) + q = q.filter(lambda stud: exists(x for x in Student if stud.name < x.name)) + self.assertEqual(set(q), {Student[1], Student[2]}) if __name__ == "__main__": From ca3113c97a0bcd225a0001937366243e817d551f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 25 Oct 2016 18:09:37 +0300 Subject: [PATCH 105/547] Change a way aliases in SQL query are generated in order to fix a problem when a subquery alias mask a base query alias (see test_lambda_4 in test_declarative_sqltranslator.py) --- pony/orm/sqltranslation.py | 34 +++++++++---------- .../test_declarative_join_optimization.py | 2 +- .../tests/test_declarative_sqltranslator.py | 4 +++ .../tests/test_declarative_sqltranslator2.py | 4 +-- pony/orm/tests/test_relations_one2one3.py | 18 +++++----- 5 files changed, 32 insertions(+), 30 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index be33f9880..37ceb3ec8 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -843,7 +843,7 @@ def __init__(subquery, parent_subquery=None, left_join=False): subquery.alias_counters = {} subquery.expr_counter = itertools.count(1) else: - subquery.alias_counters = parent_subquery.alias_counters.copy() + subquery.alias_counters = parent_subquery.alias_counters subquery.expr_counter = parent_subquery.expr_counter 
subquery.used_from_subquery = False def get_tableref(subquery, name_path, from_subquery=False): @@ -862,14 +862,10 @@ def add_tableref(subquery, name_path, parent_tableref, attr): tableref = JoinedTableRef(subquery, name_path, parent_tableref, attr) tablerefs[name_path] = tableref return tableref - def get_short_alias(subquery, name_path, entity_name): - if name_path: - if is_ident(name_path): return name_path - if not options.SIMPLE_ALIASES and len(name_path) <= max_alias_length: - return name_path - name = entity_name[:max_alias_length-3].lower() + def make_alias(subquery, name): + name = name[:max_alias_length-3].lower() i = subquery.alias_counters.setdefault(name, 0) + 1 - alias = '%s-%d' % (name, i) + alias = name if i == 1 and name != 't' else '%s-%d' % (name, i) subquery.alias_counters[name] = i return alias def join_table(subquery, parent_alias, alias, table_name, join_cond): @@ -886,7 +882,8 @@ def join_table(subquery, parent_alias, alias, table_name, join_cond): class TableRef(object): def __init__(tableref, subquery, name, entity): tableref.subquery = subquery - tableref.alias = tableref.name_path = name + tableref.name_path = name + tableref.alias = subquery.make_alias(name) tableref.entity = entity tableref.joined = False tableref.can_affect_distinct = True @@ -907,6 +904,7 @@ class JoinedTableRef(object): def __init__(tableref, subquery, name_path, parent_tableref, attr): tableref.subquery = subquery tableref.name_path = name_path + tableref.var_name = name_path if is_ident(name_path) else None tableref.alias = None tableref.optimized = None tableref.parent_tableref = parent_tableref @@ -933,7 +931,7 @@ def make_join(tableref, pk_only=False): assert reverse.columns and not reverse.is_collection rentity = reverse.entity pk_columns = rentity._pk_columns_ - alias = subquery.get_short_alias(tableref.name_path, rentity.__name__) + alias = subquery.make_alias(tableref.var_name or rentity.__name__) join_cond = join_tables(parent_alias, alias, 
left_pk_columns, reverse.columns) else: if attr.pk_offset is not None: @@ -946,16 +944,16 @@ def make_join(tableref, pk_only=False): tableref.optimized = True tableref.joined = True return parent_alias, left_columns - alias = subquery.get_short_alias(tableref.name_path, entity.__name__) + alias = subquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_columns, pk_columns) elif not attr.reverse.is_collection: - alias = subquery.get_short_alias(tableref.name_path, entity.__name__) + alias = subquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, attr.reverse.columns) else: right_m2m_columns = attr.reverse_columns if attr.symmetric else attr.columns if not tableref.joined: m2m_table = attr.table - m2m_alias = subquery.get_short_alias(None, 't') + m2m_alias = subquery.make_alias('t') reverse_columns = attr.columns if attr.symmetric else attr.reverse.columns m2m_join_cond = join_tables(parent_alias, m2m_alias, left_pk_columns, reverse_columns) subquery.join_table(parent_alias, m2m_alias, m2m_table, m2m_join_cond) @@ -968,7 +966,7 @@ def make_join(tableref, pk_only=False): elif tableref.optimized: assert not pk_only m2m_alias = tableref.alias - alias = subquery.get_short_alias(tableref.name_path, entity.__name__) + alias = subquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(m2m_alias, alias, right_m2m_columns, pk_columns) if not pk_only and entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(alias) @@ -2422,7 +2420,7 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F assert len({alias for _, alias, column in groupby_columns}) == 1 if extra_grouping: - inner_alias = translator.subquery.get_short_alias(None, 't') + inner_alias = translator.subquery.make_alias('t') inner_columns = [ 'DISTINCT' ] col_mapping = {} col_names = set() @@ -2462,7 +2460,7 @@ def 
_joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F subquery_ast.append([ 'WHERE' ] + inner_conditions) subquery_ast.append([ 'GROUP_BY' ] + groupby_columns) - alias = translator.subquery.get_short_alias(None, 't') + alias = translator.subquery.make_alias('t') for cond in outer_conditions: cond[2][1] = alias translator.subquery.from_ast.append([ alias, 'SELECT', subquery_ast, sqland(outer_conditions) ]) expr_ast = [ 'COLUMN', alias, expr_name ] @@ -2603,7 +2601,7 @@ def contains(monad, item, not_in=False): new_names.append(new_name) select_ast[i] = [ 'AS', column_ast, new_name ] - alias = subquery.get_short_alias(None, 't') + alias = subquery.make_alias('t') outer_conditions = [ [ 'EQ', item_column, [ 'COLUMN', alias, new_name ] ] for item_column, new_name in izip(item_columns, new_names) ] subquery.from_ast.append([ alias, 'SELECT', subquery_ast[1:], sqland(outer_conditions) ]) @@ -2666,7 +2664,7 @@ def count(monad): [ 'AGGREGATES', [ 'COUNT', 'DISTINCT', [ 'COLUMN', alias, 'ROWID' ] ] ], from_ast, where_ast ] else: - alias = translator.subquery.get_short_alias(None, 't') + alias = translator.subquery.make_alias('t') sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], [ 'FROM', [ alias, 'SELECT', [ [ 'DISTINCT' ] + sub.expr_columns, from_ast, where_ast ] ] ] ] diff --git a/pony/orm/tests/test_declarative_join_optimization.py b/pony/orm/tests/test_declarative_join_optimization.py index 89c14de73..3427e060c 100644 --- a/pony/orm/tests/test_declarative_join_optimization.py +++ b/pony/orm/tests/test_declarative_join_optimization.py @@ -69,7 +69,7 @@ def test7(self): q = select(s for s in Student if sum(c.credits for c in Course if s.group.dept == c.dept) > 10) objects = q[:] self.assertEqual(str(q._translator.subquery.from_ast), - "['FROM', ['s', 'TABLE', 'Student'], ['group-1', 'TABLE', 'Group', ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group-1', 'number']]]]") + "['FROM', ['s', 'TABLE', 'Student'], ['group', 'TABLE', 'Group', 
['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group', 'number']]]]") if __name__ == '__main__': diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 33aff08de..bc7134908 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -361,6 +361,10 @@ def test_lambda_3(self): q = select(s for s in Student) q = q.filter(lambda stud: exists(x for x in Student if stud.name < x.name)) self.assertEqual(set(q), {Student[1], Student[2]}) + def test_lambda_4(self): + q = select(s for s in Student) + q = q.filter(lambda stud: exists(s for s in Student if stud.name < s.name)) + self.assertEqual(set(q), {Student[1], Student[2]}) if __name__ == "__main__": diff --git a/pony/orm/tests/test_declarative_sqltranslator2.py b/pony/orm/tests/test_declarative_sqltranslator2.py index 607862c90..87ea5000a 100644 --- a/pony/orm/tests/test_declarative_sqltranslator2.py +++ b/pony/orm/tests/test_declarative_sqltranslator2.py @@ -114,12 +114,12 @@ def test_distinct6(self): self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) def test_not_null1(self): q = select(g for g in Group if '123-45-67' not in g.students.tel and g.dept == Department[1]) - not_null = "IS_NOT_NULL COLUMN student-1 tel" in (" ".join(str(i) for i in flatten(q._translator.conditions))) + not_null = "IS_NOT_NULL COLUMN student tel" in (" ".join(str(i) for i in flatten(q._translator.conditions))) self.assertEqual(not_null, True) self.assertEqual(q[:], [Group[101]]) def test_not_null2(self): q = select(g for g in Group if 'John' not in g.students.name and g.dept == Department[1]) - not_null = "IS_NOT_NULL COLUMN student-1 name" in (" ".join(str(i) for i in flatten(q._translator.conditions))) + not_null = "IS_NOT_NULL COLUMN student name" in (" ".join(str(i) for i in flatten(q._translator.conditions))) self.assertEqual(not_null, False) 
self.assertEqual(q[:], [Group[101]]) def test_chain_of_attrs_inside_for1(self): diff --git a/pony/orm/tests/test_relations_one2one3.py b/pony/orm/tests/test_relations_one2one3.py index 2419267f7..dffddb5ce 100644 --- a/pony/orm/tests/test_relations_one2one3.py +++ b/pony/orm/tests/test_relations_one2one3.py @@ -38,9 +38,9 @@ def test_2(self): sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" - LEFT JOIN "Passport" "passport-1" - ON "p"."id" = "passport-1"."person" -WHERE "passport-1"."id" IS NULL''') + LEFT JOIN "Passport" "passport" + ON "p"."id" = "passport"."person" +WHERE "passport"."id" IS NULL''') @db_session def test_3(self): @@ -48,9 +48,9 @@ def test_3(self): sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" - LEFT JOIN "Passport" "passport-1" - ON "p"."id" = "passport-1"."person" -WHERE "passport-1"."id" IS NULL''') + LEFT JOIN "Passport" "passport" + ON "p"."id" = "passport"."person" +WHERE "passport"."id" IS NULL''') @db_session def test_4(self): @@ -58,9 +58,9 @@ def test_4(self): sql = self.db.last_sql self.assertEqual(sql, '''SELECT "p"."id", "p"."name" FROM "Person" "p" - LEFT JOIN "Passport" "passport-1" - ON "p"."id" = "passport-1"."person" -WHERE "passport-1"."id" IS NOT NULL''') + LEFT JOIN "Passport" "passport" + ON "p"."id" = "passport"."person" +WHERE "passport"."id" IS NOT NULL''') @db_session def test_5(self): From bda71ef9de1306bc246ab4e8026f7d09e3939bb6 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Thu, 27 Oct 2016 19:58:37 +0300 Subject: [PATCH 106/547] Update README.md --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eff8c8d06..ff58d1abb 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,12 @@ Online tool for database design Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool for prototyping. 
You can create your database diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. +Documentation +------------- + Documenation is available at [https://docs.ponyorm.com](https://docs.ponyorm.com) -The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc), it is released under Apache 2.0 license. -Please create new documentation related issues [https://github.com/ponyorm/pony-doc/issues](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. +The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc). +Please create new documentation related issues [here](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. License From 6948a8d5d6ad4660ee5f6b1ab053b073d14beef9 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Thu, 27 Oct 2016 19:58:37 +0300 Subject: [PATCH 107/547] Update README.md --- README.md | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index eff8c8d06..ff58d1abb 100644 --- a/README.md +++ b/README.md @@ -28,9 +28,12 @@ Online tool for database design Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool for prototyping. You can create your database diagram online at [https://editor.ponyorm.com](https://editor.ponyorm.com), generate the database schema based on the diagram and start working with the database using declarative queries in seconds. +Documentation +------------- + Documenation is available at [https://docs.ponyorm.com](https://docs.ponyorm.com) -The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc), it is released under Apache 2.0 license. 
-Please create new documentation related issues [https://github.com/ponyorm/pony-doc/issues](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. +The documentation source is avaliable at [https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc). +Please create new documentation related issues [here](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. License From 0cefc93323eaf9b2aedcadc23179f06ecf0e70fe Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 1 Dec 2016 06:32:33 +0300 Subject: [PATCH 108/547] Volatile attribute bug fixed --- pony/orm/core.py | 6 ++--- pony/orm/tests/test_volatile.py | 39 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 3 deletions(-) create mode 100644 pony/orm/tests/test_volatile.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 6ba4d0670..9207f13da 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2069,7 +2069,7 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): not attr.reverse and attr.converters[0].dbvals_equal(old_dbval, new_dbval)): return - bit = obj._bits_[attr] + bit = obj._bits_except_volatile_[attr] if obj._rbits_ & bit: assert old_dbval is not NOT_LOADED if new_dbval is NOT_LOADED: diff = '' @@ -3859,7 +3859,7 @@ def _set_rbits(entity, objects, attrs): if wbits is None: continue rbits = get_rbits(obj.__class__) if rbits is None: - rbits = sum(obj._bits_.get(attr, 0) for attr in attrs) + rbits = sum(obj._bits_except_volatile_.get(attr, 0) for attr in attrs) rbits_dict[obj.__class__] = rbits obj._rbits_ |= rbits & ~wbits def _parse_row_(entity, row, attr_offsets): @@ -4372,7 +4372,7 @@ def _db_set_(obj, avdict, unpickling=False): del avdict[attr] continue - bit = obj._bits_[attr] + bit = obj._bits_except_volatile_[attr] if rbits & bit: throw(UnrepeatableReadError, 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' % (obj.__class__.__name__, 
attr.name, obj, old_dbval, new_dbval)) diff --git a/pony/orm/tests/test_volatile.py b/pony/orm/tests/test_volatile.py new file mode 100644 index 000000000..680939ed4 --- /dev/null +++ b/pony/orm/tests/test_volatile.py @@ -0,0 +1,39 @@ +import sys, unittest + +from pony.orm import * +from pony.orm.tests.testutils import * + +class TestVolatile(unittest.TestCase): + def setUp(self): + db = self.db = Database('sqlite', ':memory:') + + class Item(self.db.Entity): + name = Required(str) + index = Required(int, volatile=True) + + db.generate_mapping(create_tables=True) + + with db_session: + Item(name='A', index=1) + Item(name='B', index=2) + Item(name='C', index=3) + + @db_session + def test_1(self): + db = self.db + Item = db.Item + + db.execute('update "Item" set "index" = "index" + 1') + items = Item.select(lambda item: item.index > 0).order_by(Item.id)[:] + a, b, c = items + self.assertEqual(a.index, 2) + self.assertEqual(b.index, 3) + self.assertEqual(c.index, 4) + c.index = 1 + items = Item.select()[:] # force re-read from the database + self.assertEqual(c.index, 1) + self.assertEqual(a.index, 2) + self.assertEqual(b.index, 3) + +if __name__ == '__main__': + unittest.main() From a158c84924c92412140376e80662e36d5ef84408 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 22 Nov 2016 20:59:46 +0300 Subject: [PATCH 109/547] Fix creation of self-referenced foreign keys --- pony/orm/dbschema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index a03a8bef7..e6c3f8b7e 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -137,6 +137,7 @@ def get_create_command(table): return '\n'.join(cmd) def get_objects_to_create(table, created_tables=None): if created_tables is None: created_tables = set() + created_tables.add(table) result = [ table ] indexes = [ index for index in itervalues(table.indexes) if not index.is_pk and not index.is_unique ] for index in indexes: assert index.name is not 
None @@ -152,7 +153,6 @@ def get_objects_to_create(table, created_tables=None): for foreign_key in sorted(itervalues(child_table.foreign_keys), key=lambda fk: fk.name): if foreign_key.parent_table is not table: continue result.append(foreign_key) - created_tables.add(table) return result def add_column(table, column_name, sql_type, converter, is_not_null=None, sql_default=None): return table.schema.column_class(column_name, table, sql_type, converter, is_not_null, sql_default) From a0eeab882c98f21d1317c6ddc713070025c6a189 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 3 Dec 2016 03:57:39 +0300 Subject: [PATCH 110/547] Fix clearing of volatile attributes after update: should clear dbvals as well as vals --- pony/orm/core.py | 12 ++++++++---- pony/orm/tests/test_volatile.py | 10 +++++++++- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 9207f13da..26613cce8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4641,14 +4641,18 @@ def _update_dbvals_(obj, after_create): for key, i in attr.composite_keys: keyval = tuple(get_val(attr) for attr in key) cache_indexes[key].pop(keyval, None) - del vals[attr] elif after_create and val is None: obj._rbits_ &= ~bits[attr] - del vals[attr] else: - # TODO this conversion should be unnecessary + # For normal attribute, set `dbval` to the same value as `val` after update/create + # dbvals[attr] = val converter = attr.converters[0] - dbvals[attr] = converter.val2dbval(val, obj) + dbvals[attr] = converter.val2dbval(val, obj) # TODO this conversion should be unnecessary + continue + # Clear value of volatile attribute or null values after create, because the value may be changed in the DB + del vals[attr] + dbvals.pop(attr, None) + def _save_created_(obj): auto_pk = (obj._pkval_ is None) attrs = [] diff --git a/pony/orm/tests/test_volatile.py b/pony/orm/tests/test_volatile.py index 680939ed4..663091efd 100644 --- a/pony/orm/tests/test_volatile.py +++ 
b/pony/orm/tests/test_volatile.py @@ -22,7 +22,6 @@ class Item(self.db.Entity): def test_1(self): db = self.db Item = db.Item - db.execute('update "Item" set "index" = "index" + 1') items = Item.select(lambda item: item.index > 0).order_by(Item.id)[:] a, b, c = items @@ -35,5 +34,14 @@ def test_1(self): self.assertEqual(a.index, 2) self.assertEqual(b.index, 3) + + @db_session + def test_2(self): + Item = self.db.Item + item = Item[1] + item.name = 'X' + item.flush() + self.assertEqual(item.index, 1) + if __name__ == '__main__': unittest.main() From 60ddecb060b8ea3f8637d87ec1bcbac51d8c4ddf Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Dec 2016 23:40:04 +0300 Subject: [PATCH 111/547] Bug fixed: when required attribute is empty when loading from the database it should not lead to validation error --- pony/orm/core.py | 5 ++-- pony/orm/tests/test_validate.py | 44 +++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 pony/orm/tests/test_validate.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 26613cce8..cac6415a4 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2209,8 +2209,9 @@ class Required(Attribute): __slots__ = [] def validate(attr, val, obj=None, entity=None, from_db=False): val = Attribute.validate(attr, val, obj, entity, from_db) - if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): - throw(ValueError, 'Attribute %s is required' % (attr if obj is None else '%r.%s' % (obj, attr.name))) + if not from_db: + if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): + throw(ValueError, 'Attribute %s is required' % (attr if obj is None else '%r.%s' % (obj, attr.name))) return val class Discriminator(Required): diff --git a/pony/orm/tests/test_validate.py b/pony/orm/tests/test_validate.py new file mode 100644 index 000000000..76f9846f0 --- /dev/null +++ b/pony/orm/tests/test_validate.py @@ -0,0 +1,44 @@ +import 
unittest + +from pony.orm import * +from pony.orm.tests.testutils import raises_exception + +db = Database('sqlite', ':memory:') + +class Person(db.Entity): + id = PrimaryKey(int) + name = Required(str) + tel = Optional(str) + +db.generate_mapping(check_tables=False) + +with db_session: + db.execute(""" + create table Person( + id int primary key, + name text, + tel text + ) + """) + + +class TestValidate(unittest.TestCase): + + @db_session + def setUp(self): + db.execute('delete from Person') + + @db_session + def test_1(self): + db.insert('Person', id=1, name='', tel='111') + p = Person.get(id=1) + self.assertEqual(p.name, '') + + @db_session + def test_2(self): + db.insert('Person', id=1, name=None, tel='111') + p = Person.get(id=1) + self.assertEqual(p.name, None) + +if __name__ == '__main__': + unittest.main() From 6516a8dbfc2f054fe1ed5e1f3ff684de885f8309 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 20 Dec 2016 17:44:09 +0300 Subject: [PATCH 112/547] Fix throwing InvalidQuery: "Use generator expression (... for ... in ...) instead of list comprehension [... for ... in ...] inside query" --- pony/orm/decompiling.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index bed528309..0a6a04b97 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -274,8 +274,9 @@ def JUMP_FORWARD(decompiler, endpos): if decompiler.targets.get(endpos) is then: decompiler.targets[endpos] = if_exp return if_exp - def LIST_APPEND(decompiler): - throw(NotImplementedError) + def LIST_APPEND(decompiler, offset=None): + throw(InvalidQuery('Use generator expression (... for ... in ...) ' + 'instead of list comprehension [... for ... in ...] 
inside query')) def LOAD_ATTR(decompiler, attr_name): return ast.Getattr(decompiler.stack.pop(), attr_name) @@ -379,7 +380,8 @@ def STORE_DEREF(decompiler, freevar): def STORE_FAST(decompiler, varname): if varname.startswith('_['): - throw(InvalidQuery('Use generator expression (... for ... in ...) instead of list comprehension [... for ... in ...] inside query')) + throw(InvalidQuery('Use generator expression (... for ... in ...) ' + 'instead of list comprehension [... for ... in ...] inside query')) decompiler.assnames.add(varname) decompiler.store(ast.AssName(varname, 'OP_ASSIGN')) From e19b4463bff110047ada4dd9685b3b42c9c2644b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 21 Dec 2016 11:49:01 +0300 Subject: [PATCH 113/547] New warning DatabaseContainsIncorrectEmptyValue added --- pony/orm/core.py | 17 ++++++++++-- pony/orm/tests/test_validate.py | 49 +++++++++++++++++++++++++++------ 2 files changed, 54 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index cac6415a4..c5e0af540 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2,7 +2,7 @@ from pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \ basestring, unicode, buffer, int_types, builtins, pickle, with_metaclass -import io, json, re, sys, types, datetime, logging, itertools +import io, json, re, sys, types, datetime, logging, itertools, warnings from operator import attrgetter, itemgetter from itertools import chain, starmap, repeat from time import time @@ -42,6 +42,7 @@ ObjectNotFound MultipleObjectsFoundError TooManyObjectsFoundError OperationWithDeletedObjectError TransactionError ConnectionClosedError TransactionIntegrityError IsolationError CommitException RollbackException UnrepeatableReadError OptimisticCheckError UnresolvableCyclicDependency UnexpectedError DatabaseSessionIsOver + DatabaseContainsIncorrectValue DatabaseContainsIncorrectEmptyValue TranslationError ExprEvalError @@ -187,6 +188,12 
@@ def __init__(exc, src, cause): class OptimizationFailed(Exception): pass # Internal exception, cannot be encountered in user code +class DatabaseContainsIncorrectValue(RuntimeWarning): + pass + +class DatabaseContainsIncorrectEmptyValue(DatabaseContainsIncorrectValue): + pass + def adapt_sql(sql, paramstyle): result = adapted_sql_cache.get((sql, paramstyle)) if result is not None: return result @@ -2209,9 +2216,13 @@ class Required(Attribute): __slots__ = [] def validate(attr, val, obj=None, entity=None, from_db=False): val = Attribute.validate(attr, val, obj, entity, from_db) - if not from_db: - if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): + if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): + if not from_db: throw(ValueError, 'Attribute %s is required' % (attr if obj is None else '%r.%s' % (obj, attr.name))) + else: + warnings.warn('Database contains %s for required attribute %s' + % ('NULL' if val is None else 'empty string', attr), + DatabaseContainsIncorrectEmptyValue) return val class Discriminator(Required): diff --git a/pony/orm/tests/test_validate.py b/pony/orm/tests/test_validate.py index 76f9846f0..380068bff 100644 --- a/pony/orm/tests/test_validate.py +++ b/pony/orm/tests/test_validate.py @@ -1,6 +1,7 @@ -import unittest +import unittest, warnings from pony.orm import * +from pony.orm import core from pony.orm.tests.testutils import raises_exception db = Database('sqlite', ':memory:') @@ -21,24 +22,54 @@ class Person(db.Entity): ) """) +warnings.simplefilter('error', ) + class TestValidate(unittest.TestCase): @db_session def setUp(self): db.execute('delete from Person') + registry = getattr(core, '__warningregistry__', {}) + for key in list(registry): + text, category, lineno = key + if category is DatabaseContainsIncorrectEmptyValue: + del registry[key] + + @db_session + def test_1a(self): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', 
DatabaseContainsIncorrectEmptyValue) + db.insert('Person', id=1, name='', tel='111') + p = Person.get(id=1) + self.assertEqual(p.name, '') + @raises_exception(DatabaseContainsIncorrectEmptyValue, + 'Database contains empty string for required attribute Person.name') @db_session - def test_1(self): - db.insert('Person', id=1, name='', tel='111') - p = Person.get(id=1) - self.assertEqual(p.name, '') + def test_1b(self): + with warnings.catch_warnings(): + warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue) + db.insert('Person', id=1, name='', tel='111') + p = Person.get(id=1) @db_session - def test_2(self): - db.insert('Person', id=1, name=None, tel='111') - p = Person.get(id=1) - self.assertEqual(p.name, None) + def test_2a(self): + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue) + db.insert('Person', id=1, name=None, tel='111') + p = Person.get(id=1) + self.assertEqual(p.name, None) + + @raises_exception(DatabaseContainsIncorrectEmptyValue, + 'Database contains NULL for required attribute Person.name') + @db_session + def test_2b(self): + with warnings.catch_warnings(): + warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue) + db.insert('Person', id=1, name=None, tel='111') + p = Person.get(id=1) + if __name__ == '__main__': unittest.main() From 5d29397e0462a4c93dc1a81a272ad6482460fb19 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 27 Dec 2016 14:14:13 +0300 Subject: [PATCH 114/547] Fix bug introduced in ca3113c9: query optimization lead to invalid query --- pony/orm/sqltranslation.py | 2 +- pony/orm/tests/test_declarative_sqltranslator.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 37ceb3ec8..3ebe4f908 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -843,7 +843,7 @@ def __init__(subquery, parent_subquery=None, left_join=False): 
subquery.alias_counters = {} subquery.expr_counter = itertools.count(1) else: - subquery.alias_counters = parent_subquery.alias_counters + subquery.alias_counters = parent_subquery.alias_counters.copy() subquery.expr_counter = parent_subquery.expr_counter subquery.used_from_subquery = False def get_tableref(subquery, name_path, from_subquery=False): diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index bc7134908..396732a63 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -77,6 +77,7 @@ class Room(db.Entity): g1.rooms = [ r1, r2 ] g2.rooms = [ r2, r3 ] c1.students.add(s1) + c1.students.add(s2) c2.students.add(s2) db2 = Database('sqlite', ':memory:') @@ -283,13 +284,13 @@ def test_composite_key1(self): self.assertEqual(result, {Teacher.get(name='T1')}) def test_composite_key2(self): result = set(select(s for s in Student if Course['Math', 1] in s.courses)) - self.assertEqual(result, {Student[1]}) + self.assertEqual(result, {Student[1], Student[2]}) def test_composite_key3(self): result = set(select(s for s in Student if Course['Math', 1] not in s.courses)) - self.assertEqual(result, {Student[2], Student[3]}) + self.assertEqual(result, {Student[3]}) def test_composite_key4(self): result = set(select(s for s in Student if len(c for c in Course if c not in s.courses) == 2)) - self.assertEqual(result, {Student[1], Student[2]}) + self.assertEqual(result, {Student[1]}) def test_composite_key5(self): result = set(select(s for s in Student if not (c for c in Course if c not in s.courses))) self.assertEqual(result, set()) @@ -332,7 +333,7 @@ def test_hint_join1(self): self.assertEqual(result, {Student[2]}) def test_hint_join2(self): result = set(select(c for c in Course if JOIN(len(c.students) == 1))) - self.assertEqual(result, {Course['Math', 1], Course['Economics', 1]}) + self.assertEqual(result, {Course['Economics', 1]}) def 
test_tuple_param(self): x = Student[1], Student[2] result = set(select(s for s in Student if s not in x)) @@ -365,6 +366,12 @@ def test_lambda_4(self): q = select(s for s in Student) q = q.filter(lambda stud: exists(s for s in Student if stud.name < s.name)) self.assertEqual(set(q), {Student[1], Student[2]}) + def test_optimized_1(self): + q = select((g, count(g.students)) for g in Group if count(g.students) > 1) + self.assertEqual(set(q), {(Group[1], 2)}) + def test_optimized_2(self): + q = select((s, count(s.courses)) for s in Student if count(s.courses) > 1) + self.assertEqual(set(q), {(Student[2], 2)}) if __name__ == "__main__": From 5ab01cb0bc7df8d848a268a9935f635331d3cb78 Mon Sep 17 00:00:00 2001 From: Anton Grudko Date: Mon, 26 Dec 2016 13:47:20 +0300 Subject: [PATCH 115/547] Fix #216: decompiler fixes for Python 3.6 --- pony/orm/decompiling.py | 69 +++++++++++++++---- pony/orm/tests/test_declarative_func_monad.py | 5 +- pony/orm/tests/test_validate.py | 1 + setup.py | 5 +- 4 files changed, 64 insertions(+), 16 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 0a6a04b97..8fa9ae700 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -77,23 +77,32 @@ def __init__(decompiler, code, start=0, end=None): decompiler.external_names = decompiler.names - decompiler.assnames assert not decompiler.stack, decompiler.stack def decompile(decompiler): + PY36 = sys.version_info >= (3, 6) code = decompiler.code co_code = code.co_code free = code.co_cellvars + code.co_freevars try: + extended_arg = 0 while decompiler.pos < decompiler.end: i = decompiler.pos if i in decompiler.targets: decompiler.process_target(i) op = ord(code.co_code[i]) - i += 1 - if op >= HAVE_ARGUMENT: - oparg = ord(co_code[i]) + ord(co_code[i+1])*256 + if PY36: + if op >= HAVE_ARGUMENT: + oparg = ord(co_code[i + 1]) | extended_arg + extended_arg = (arg << 8) if op == EXTENDED_ARG else 0 i += 2 - if op == EXTENDED_ARG: - op = ord(code.co_code[i]) - i += 1 - 
oparg = ord(co_code[i]) + ord(co_code[i+1])*256 + oparg*65536 + else: + i += 1 + if op >= HAVE_ARGUMENT: + oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 i += 2 + if op == EXTENDED_ARG: + op = ord(code.co_code[i]) + i += 1 + oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 + oparg * 65536 + i += 2 + if op >= HAVE_ARGUMENT: if op in hasconst: arg = [code.co_consts[oparg]] elif op in hasname: arg = [code.co_names[oparg]] elif op in hasjrel: arg = [i + oparg] @@ -147,6 +156,14 @@ def BINARY_SUBSCR(decompiler): if isinstance(oper2, ast.Tuple): return ast.Subscript(oper1, 'OP_APPLY', list(oper2.nodes)) else: return ast.Subscript(oper1, 'OP_APPLY', [ oper2 ]) + def BUILD_CONST_KEY_MAP(decompiler, length): + keys = decompiler.stack.pop() + assert isinstance(keys, ast.Const) + keys = [ ast.Const(key) for key in keys.value ] + values = decompiler.pop_items(length) + pairs = list(izip(keys, values)) + return ast.Dict(pairs) + def BUILD_LIST(decompiler, size): return ast.List(decompiler.pop_items(size)) @@ -177,7 +194,10 @@ def CALL_FUNCTION(decompiler, argc, star=None, star2=None): args.append(ast.Keyword(key, arg)) for i in xrange(posarg): args.append(pop()) args.reverse() - tos = pop() + return decompiler._call_function(args, star, star2) + + def _call_function(decompiler, args, star=None, star2=None): + tos = decompiler.stack.pop() if isinstance(tos, ast.GenExpr): assert len(args) == 1 and star is None and star2 is None genexpr = tos @@ -192,13 +212,31 @@ def CALL_FUNCTION_VAR(decompiler, argc): return decompiler.CALL_FUNCTION(argc, decompiler.stack.pop()) def CALL_FUNCTION_KW(decompiler, argc): - return decompiler.CALL_FUNCTION(argc, None, decompiler.stack.pop()) + if sys.version_info < (3, 6): + return decompiler.CALL_FUNCTION(argc, star2=decompiler.stack.pop()) + keys = decompiler.stack.pop() + assert isinstance(keys, ast.Const) + keys = keys.value + values = decompiler.pop_items(argc) + assert len(keys) <= len(values) + args = values[:-len(keys)] + for key, 
value in izip(keys, values[-len(keys):]): + args.append(ast.Keyword(key, value)) + return decompiler._call_function(args) def CALL_FUNCTION_VAR_KW(decompiler, argc): star2 = decompiler.stack.pop() star = decompiler.stack.pop() return decompiler.CALL_FUNCTION(argc, star, star2) + def CALL_FUNCTION_EX(decompiler, argc): + star2 = None + if argc: + if argc != 1: throw(NotImplementedError) + star2 = decompiler.stack.pop() + star = decompiler.stack.pop() + return decompiler._call_function([], star, star2) + def COMPARE_OP(decompiler, op): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() @@ -310,9 +348,16 @@ def MAKE_CLOSURE(decompiler, argc): return decompiler.MAKE_FUNCTION(argc) def MAKE_FUNCTION(decompiler, argc): - if argc: throw(NotImplementedError) - tos = decompiler.stack.pop() - if not PY2: tos = decompiler.stack.pop() + if sys.version_info >= (3, 6): + if argc: + if argc != 0x08: throw(NotImplementedError, argc) + qualname = decompiler.stack.pop() + tos = decompiler.stack.pop() + if (argc & 0x08): func_closure = decompiler.stack.pop() + else: + if argc: throw(NotImplementedError) + tos = decompiler.stack.pop() + if not PY2: tos = decompiler.stack.pop() codeobject = tos.value func_decompiler = Decompiler(codeobject) # decompiler.names.update(decompiler.names) ??? 
diff --git a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index a0c17d51e..834e08cc9 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2 -import unittest +import sys, unittest from datetime import date, datetime from decimal import Decimal @@ -115,7 +115,8 @@ def test_datetime_now1(self): self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) @raises_exception(ExprEvalError, "1 < datetime.now() raises TypeError: " + ("can't compare datetime.datetime to int" if PY2 else - "unorderable types: int() < datetime.datetime()")) + "unorderable types: int() < datetime.datetime()" if sys.version_info < (3, 6) else + "'<' not supported between instances of 'int' and 'datetime.datetime'")) def test_datetime_now2(self): select(s for s in Student if 1 < datetime.now()) def test_datetime_now3(self): diff --git a/pony/orm/tests/test_validate.py b/pony/orm/tests/test_validate.py index 380068bff..3ff4425d5 100644 --- a/pony/orm/tests/test_validate.py +++ b/pony/orm/tests/test_validate.py @@ -32,6 +32,7 @@ def setUp(self): db.execute('delete from Person') registry = getattr(core, '__warningregistry__', {}) for key in list(registry): + if type(key) is not tuple: continue text, category, lineno = key if category is DatabaseContainsIncorrectEmptyValue: del registry[key] diff --git a/setup.py b/setup.py index e3cc03145..612f77fe6 100644 --- a/setup.py +++ b/setup.py @@ -63,6 +63,7 @@ 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', 'Topic :: Software Development :: Libraries', 'Topic :: Database' ] @@ -88,8 +89,8 @@ if __name__ == "__main__": pv = sys.version_info[:2] - if pv not in ((2, 7), (3, 3), (3, 4), (3, 5)): - s = 
"Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3, 3.4 and 3.5." \ + if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6)): + s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.6." \ " You have version %s" print(s % (name, version, sys.version.split(' ', 1)[0])) sys.exit(1) From f418602c5f1efbce9ee2f13125b57c933dedc4e9 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Tue, 10 Jan 2017 16:41:16 +0300 Subject: [PATCH 116/547] Update CHANGELOG and the version: 0.7.1-dev -> 0.7.1 --- CHANGELOG.md | 17 +++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d32da2a7b..f2619b064 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,20 @@ +# Pony ORM Release 0.7.1 (2017-01-10) + +## New features + +* New warning DatabaseContainsIncorrectEmptyValue added, it is raised when the required attribute is empty during loading an entity from the database + +## Bugfixes + +* Fixes #216: Added Python 3.6 support +* Fixes #203: subtranslator should use argnames from parent translator +* Change a way aliases in SQL query are generated in order to fix a problem when a subquery alias masks a base query alias +* Volatile attribute bug fixed +* Fix creation of self-referenced foreign keys - before this Pony didn't create the foreign key for self-referenced attributes +* Bug fixed: when required attribute is empty the loading from the database shouldn't raise the validation error. Now Pony raises the warning DatabaseContainsIncorrectEmptyValue +* Throw an error with more clear explanation when a list comprehension is used inside a query instead of a generator expression: "Use generator expression (... for ... in ...) instead of list comprehension [... for ... in ...] inside query" + + # Pony ORM Release 0.7 (2016-10-11) Starting with this release Pony ORM is release under the Apache License, Version 2.0. 
diff --git a/pony/__init__.py b/pony/__init__.py index 77376cb2e..3264e8d1b 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.1-dev' +__version__ = '0.7.1' uid = str(random.randint(1, 1000000)) From d03e6bee84af86f1e86e78130e5c28ad6d2ed4c6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 13 Jan 2017 06:47:48 +0300 Subject: [PATCH 117/547] Change Pony version: 0.7.1 -> 0.7.2-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 3264e8d1b..c6307af36 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.1' +__version__ = '0.7.2-dev' uid = str(random.randint(1, 1000000)) From 1146ff1f4c3525393d4636406ae41213a628f4b1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 13 Jan 2017 06:40:19 +0300 Subject: [PATCH 118/547] Fixes #219: when a connection was closed due to some exception, implicit rollback during exiting from db_session should not mask that exception with another 'connection already closed' exception --- pony/orm/core.py | 21 +++++++++++++-------- pony/orm/tests/test_db_session.py | 7 +++++++ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index c5e0af540..9e08db55d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -283,6 +283,12 @@ def transact_reraise(exc_class, exceptions): reraise(exc_class, new_exc, tb) finally: del exceptions, exc, tb, new_exc +def rollback_and_reraise(exc_info): + try: + rollback() + finally: + reraise(*exc_info) + @cut_traceback def commit(): caches = _get_caches() @@ -292,8 +298,7 @@ def commit(): for cache in caches: cache.flush() except: - rollback() - raise + rollback_and_reraise(sys.exc_info()) primary_cache = caches[0] other_caches = caches[1:] @@ -387,14 +392,15 @@ def __exit__(db_session, 
exc_type=None, exc=None, tb=None): else: assert exc is not None # exc can be None in Python 2.6 even if exc_type is not None try: can_commit = db_session.allowed_exceptions(exc) - except: - rollback() - raise + except: rollback_and_reraise(sys.exc_info()) if can_commit: commit() for cache in _get_caches(): cache.release() assert not local.db2cache - else: rollback() + else: + try: rollback() + except: + if exc_type is None: raise # if exc_type is not None it will be reraised outside of __exit__ finally: del exc, tb local.db_session = None @@ -465,8 +471,7 @@ def wrapped_interact(iterator, input=None, exc_info=None): if cache.modified or cache.in_transaction: throw(TransactionError, 'You need to manually commit() changes before yielding from the generator') except: - rollback() - raise + rollback_and_reraise(sys.exc_info()) else: return output finally: diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 3d56e060a..7bf2c1160 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -325,6 +325,13 @@ def before_insert(self): db.commit() # Should raise ZeroDivisionError and not CommitException + @raises_exception(ZeroDivisionError) + def test_db_session_exceptions_4(self): + with db_session: + connection = self.db.get_connection() + connection.close() + 1/0 + db = Database('sqlite', ':memory:') class Group(db.Entity): From a771845cce4794bdd42238eb3fb7143ee9c3cd63 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 2 Mar 2017 14:33:40 +0300 Subject: [PATCH 119/547] Improved assertions --- pony/orm/core.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 9e08db55d..b87d4051e 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2096,7 +2096,7 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): wbit = bool(obj._wbits_ & bit) if not wbit: old_val = obj._vals_.get(attr, NOT_LOADED) - assert old_val == old_dbval + 
assert old_val == old_dbval, (old_val, old_dbval) if attr.is_part_of_unique_index: cache = obj._session_cache_ if attr.is_unique: cache.db_update_simple_index(obj, attr, old_val, new_dbval) @@ -4390,9 +4390,13 @@ def _db_set_(obj, avdict, unpickling=False): continue bit = obj._bits_except_volatile_[attr] - if rbits & bit: throw(UnrepeatableReadError, - 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' - % (obj.__class__.__name__, attr.name, obj, old_dbval, new_dbval)) + if rbits & bit: + errormsg = 'Please contact PonyORM developers so they can ' \ + 'reproduce your error and fix a bug: support@ponyorm.com' + assert old_dbval is not NOT_LOADED, errormsg + throw(UnrepeatableReadError, + 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' + % (obj.__class__.__name__, attr.name, obj, old_dbval, new_dbval)) if attr.reverse: attr.db_update_reverse(obj, old_dbval, new_dbval) obj._dbvals_[attr] = new_dbval From 9f909f5a4f00b3387c33ea78a8835fb7af44388f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 16 Feb 2017 14:00:54 +0300 Subject: [PATCH 120/547] Fix @cut_traceback decorator --- pony/utils/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index d8d4e3b5f..03ef1e040 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -84,6 +84,7 @@ def cut_traceback(func, *args, **kwargs): except AssertionError: raise except Exception: exc_type, exc, tb = sys.exc_info() + full_tb = tb last_pony_tb = None try: while tb.tb_next: @@ -93,11 +94,11 @@ def cut_traceback(func, *args, **kwargs): last_pony_tb = tb tb = tb.tb_next if last_pony_tb is None: raise - if tb.tb_frame.f_globals.get('__name__') == 'pony.utils' and tb.tb_frame.f_code.co_name == 'throw': + if tb.tb_frame.f_globals.get('__name__').startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': reraise(exc_type, exc, last_pony_tb) - raise exc # Set 
"pony.options.CUT_TRACEBACK = False" to see full traceback + reraise(exc_type, exc, full_tb) finally: - del exc, tb, last_pony_tb + del exc, full_tb, tb, last_pony_tb if PY2: exec('''def reraise(exc_type, exc, tb): @@ -492,4 +493,4 @@ def concat(*args): return ''.join(tostring(arg) for arg in args) def is_utf8(encoding): - return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') \ No newline at end of file + return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') From a4a581ded93feeae470fd85b01685eda14d74ff8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 3 Feb 2017 17:36:27 +0300 Subject: [PATCH 121/547] Fix for JSON_NE --- pony/orm/dbproviders/mysql.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index bb8346bf0..a3aa63187 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -103,9 +103,9 @@ def JSON_NONZERO(builder, expr): return 'COALESCE(CAST(', builder(expr), ''' as CHAR), 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): return 'json_length(', builder(value), ')' - def EQ_JSON(builder, left, right): + def JSON_EQ(builder, left, right): return '(', builder(left), ' = CAST(', builder(right), ' AS JSON))' - def NE_JSON(builder, left, right): + def JSON_NE(builder, left, right): return '(', builder(left), ' != CAST(', builder(right), ' AS JSON))' def JSON_CONTAINS(builder, expr, path, key): key_sql = builder(key) @@ -163,7 +163,8 @@ def sql_type(converter): return 'BINARY(16)' class MySQLJsonConverter(dbapiprovider.JsonConverter): - EQ = 'EQ_JSON' + EQ = 'JSON_EQ' + NE = 'JSON_NE' def init(self, kwargs): if self.provider.server_version < (5, 7, 8): version = '.'.join(imap(str, self.provider.server_version)) From a00e4072962c90d4e257863a8dd1baa300851953 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 3 Feb 2017 19:05:34 
+0300 Subject: [PATCH 122/547] Optional optimistic checks for float attributes --- pony/orm/core.py | 5 +++-- pony/orm/dbapiprovider.py | 3 +++ pony/orm/dbproviders/sqlite.py | 7 ++++++- pony/orm/sqlbuilding.py | 6 ++++++ 4 files changed, 18 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b87d4051e..63e50c138 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1796,7 +1796,7 @@ def __init__(attr, py_type, *args, **kwargs): attr.lazy = kwargs.pop('lazy', getattr(py_type, 'lazy', False)) attr.lazy_sql_cache = None attr.is_volatile = kwargs.pop('volatile', False) - attr.optimistic = kwargs.pop('optimistic', True) + attr.optimistic = kwargs.pop('optimistic', None) attr.sql_default = kwargs.pop('sql_default', None) attr.py_check = kwargs.pop('py_check', None) attr.hidden = kwargs.pop('hidden', False) @@ -4620,7 +4620,8 @@ def _construct_optimistic_criteria_(obj): for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): converters = attr.converters assert converters - if not (attr.optimistic and converters[0].optimistic): continue + optimistic = attr.optimistic if attr.optimistic is not None else converters[0].optimistic + if not optimistic: continue dbval = obj._dbvals_[attr] optimistic_columns.extend(attr.columns) optimistic_converters.extend(attr.converters) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 69378b227..d3e0a7128 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -506,10 +506,13 @@ def sql_type(converter): return converter.unsigned_types.get(converter.size) class RealConverter(Converter): + EQ = 'FLOAT_EQ' + NE = 'FLOAT_NE' # The tolerance is necessary for Oracle, because it has different representation of float numbers. 
# For other databases the default tolerance is set because the precision can be lost during # Python -> JavaScript -> Python conversion default_tolerance = 1e-14 + optimistic = False def init(converter, kwargs): Converter.init(converter, kwargs) min_val = kwargs.pop('min', None) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index ef96a9400..adcf09813 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -132,7 +132,12 @@ def RANDOM(builder): return 'rand()' # return '(random() / 9223372036854775807.0 + 1.0) / 2.0' PY_UPPER = make_unary_func('py_upper') PY_LOWER = make_unary_func('py_lower') - + def FLOAT_EQ(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' + def FLOAT_NE(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(max(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14' def JSON_QUERY(builder, expr, path): fname = 'json_extract' if builder.json1_available else 'py_json_extract' path_sql, has_params, has_wildcards = builder.build_json_path(path) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index cd82ab2ef..bd1913bde 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -400,6 +400,12 @@ def POW(builder, expr1, expr2): DIV = make_binary_op(' / ', True) FLOORDIV = make_binary_op(' / ', True) + def FLOAT_EQ(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' + def FLOAT_NE(builder, a, b): + a, b = builder(a), builder(b) + return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) > 1e-14' def CONCAT(builder, *args): return '(', join(' || ', imap(builder, args)), ')' def NEG(builder, expr): From a07e3daec9b747a1d0dd6b8e0bc9a5febe3ffe91 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 20 
Feb 2017 12:10:44 +0300 Subject: [PATCH 123/547] Fix Oracle connection string in example file --- pony/orm/examples/university1.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/examples/university1.py b/pony/orm/examples/university1.py index cb7447982..699288ff8 100644 --- a/pony/orm/examples/university1.py +++ b/pony/orm/examples/university1.py @@ -45,7 +45,7 @@ class Student(db.Entity): db.bind('sqlite', 'university1.sqlite', create_db=True) #db.bind('mysql', host="localhost", user="pony", passwd="pony", db="university1") #db.bind('postgres', user='pony', password='pony', host='localhost', database='university1') -#db.bind('oracle', 'university1/pony@localhost') +#db.bind('oracle', 'c##pony/pony@localhost/orcl') db.generate_mapping(create_tables=True) From f86bd3aea3c482c16c9b8a5072b310079e206f65 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 16 Jan 2017 12:45:27 +0300 Subject: [PATCH 124/547] All arguments of Database(...) or db.bind(...) can be specified as keyword arguments --- pony/orm/core.py | 6 +++--- pony/orm/examples/university1.py | 11 +++++++---- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 63e50c138..46ec8208e 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -547,9 +547,9 @@ def _bind(self, *args, **kwargs): # argument 'self' cannot be named 'database', because 'database' can be in kwargs if self.provider is not None: throw(TypeError, 'Database object was already bound to %s provider' % self.provider.dialect) - if not args: - throw(TypeError, 'Database provider should be specified as a first positional argument') - provider, args = args[0], args[1:] + if args: provider, args = args[0], args[1:] + elif 'provider' not in kwargs: throw(TypeError, 'Database provider is not specified') + else: provider = kwargs.pop('provider') if isinstance(provider, type) and issubclass(provider, DBAPIProvider): provider_cls = provider else: diff --git 
a/pony/orm/examples/university1.py b/pony/orm/examples/university1.py index 699288ff8..81a1b91f9 100644 --- a/pony/orm/examples/university1.py +++ b/pony/orm/examples/university1.py @@ -42,10 +42,13 @@ class Student(db.Entity): sql_debug(True) # Output all SQL queries to stdout -db.bind('sqlite', 'university1.sqlite', create_db=True) -#db.bind('mysql', host="localhost", user="pony", passwd="pony", db="university1") -#db.bind('postgres', user='pony', password='pony', host='localhost', database='university1') -#db.bind('oracle', 'c##pony/pony@localhost/orcl') +params = dict( + sqlite=dict(provider='sqlite', filename='university1.sqlite', create_db=True), + mysql=dict(provider='mysql', host="localhost", user="pony", passwd="pony", db="pony"), + postgres=dict(provider='postgres', user='pony', password='pony', host='localhost', database='pony'), + oracle=dict(provider='oracle', user='c##pony', password='pony', dsn='localhost/orcl') +) +db.bind(**params['sqlite']) db.generate_mapping(create_tables=True) From 4d56aaa485f1eb570a967f31e37d74d2684fcc21 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 20 Feb 2017 01:49:31 +0300 Subject: [PATCH 125/547] Fix Oracle provider handling of keyword arguments --- pony/orm/dbproviders/oracle.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index d7b4636a0..600b8880d 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -475,12 +475,16 @@ def get_pool(provider, *args, **kwargs): elif len(args) == 2: user, password = args elif len(args) == 3: user, password, dsn = args elif args: throw(ValueError, 'Invalid number of positional arguments') - if user != kwargs.setdefault('user', user): - throw(ValueError, 'Ambiguous value for user') - if password != kwargs.setdefault('password', password): - throw(ValueError, 'Ambiguous value for password') - if dsn != kwargs.setdefault('dsn', dsn): - 
throw(ValueError, 'Ambiguous value for dsn') + + def setdefault(kwargs, key, value): + kwargs_value = kwargs.setdefault(key, value) + if value is not None and value != kwargs_value: + throw(ValueError, 'Ambiguous value for ' + key) + + setdefault(kwargs, 'user', user) + setdefault(kwargs, 'password', password) + setdefault(kwargs, 'dsn', dsn) + kwargs.setdefault('threaded', True) kwargs.setdefault('min', 1) kwargs.setdefault('max', 10) From 87e3dcfb4c8c0dcdf4bc0cd0f95270936563ee25 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 2 Mar 2017 19:46:43 +0300 Subject: [PATCH 126/547] Fixes #232: negate for numeric expression should check if value is NULL --- pony/orm/sqltranslation.py | 6 +++++- pony/orm/tests/test_crud.py | 19 +++++++++++++++++-- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 3ebe4f908..f155008e2 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1275,7 +1275,11 @@ def nonzero(monad): return translator.CmpMonad('!=', monad, translator.ConstMonad.new(translator, 0)) def negate(monad): translator = monad.translator - return translator.CmpMonad('==', monad, translator.ConstMonad.new(translator, 0)) + result = translator.CmpMonad('==', monad, translator.ConstMonad.new(translator, 0)) + if isinstance(monad, translator.AttrMonad) and not monad.attr.nullable: + return result + sql = [ 'OR', result.getsql()[0], [ 'IS_NULL', monad.getsql()[0] ] ] + return translator.BoolExprMonad(translator, sql) def numeric_attr_factory(name): def attr_func(monad): diff --git a/pony/orm/tests/test_crud.py b/pony/orm/tests/test_crud.py index c84ea50db..ea4fcc0f6 100644 --- a/pony/orm/tests/test_crud.py +++ b/pony/orm/tests/test_crud.py @@ -16,6 +16,7 @@ class Group(db.Entity): class Student(db.Entity): name = Required(unicode) + age = Optional(int) scholarship = Required(Decimal, default=0) picture = Optional(buffer, lazy=True) email = Required(unicode, 
unique=True) @@ -34,8 +35,8 @@ class Course(db.Entity): with db_session: g1 = Group(id=1, major='Math') g2 = Group(id=2, major='Physics') - s1 = Student(id=1, name='S1', email='s1@example.com', group=g1) - s2 = Student(id=2, name='S2', email='s2@example.com', group=g1) + s1 = Student(id=1, name='S1', age=19, email='s1@example.com', group=g1) + s2 = Student(id=2, name='S2', age=21, email='s2@example.com', group=g1) s3 = Student(id=3, name='S3', email='s3@example.com', group=g2) c1 = Course(name='Math', semester=1) c2 = Course(name='Math', semester=2) @@ -72,6 +73,20 @@ def test_exists_3(self): x = Student.exists(group=g1) self.assertEqual(x, True) + def test_numeric_nonzero(self): + result = select(s.id for s in Student if s.age)[:] + self.assertEqual(set(result), {1, 2}) + + def test_numeric_negate_1(self): + result = select(s.id for s in Student if not s.age)[:] + self.assertEqual(set(result), {3}) + self.assertTrue('is null' in db.last_sql.lower()) + + def test_numeric_negate_2(self): + result = select(c.id for c in Course if not c.semester)[:] + self.assertEqual(result, []) + self.assertTrue('is null' not in db.last_sql.lower()) + def test_set1(self): s1 = Student[1] s1.set(name='New name', scholarship=100) From 40e3b0705d7556d42be0767117d9b05781c39118 Mon Sep 17 00:00:00 2001 From: mikhail lazko Date: Mon, 27 Feb 2017 16:26:36 +0700 Subject: [PATCH 127/547] fixed __ne__ for MethodType --- pony/orm/ormtypes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index c9278d208..ce591746b 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -60,7 +60,7 @@ def __init__(self, method): def __eq__(self, other): return type(other) is MethodType and self.obj == other.obj and self.func == other.func def __ne__(self, other): - return type(other) is not SetType or self.obj != other.obj or self.func != other.func + return type(other) is not MethodType or self.obj != other.obj or self.func != other.func 
def __hash__(self): return hash(self.obj) ^ hash(self.func) From d7cce6c5f6064dce46d13334255548033583be35 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 21 Mar 2017 15:35:06 +0300 Subject: [PATCH 128/547] Fix bug when discriminator column is used as a part of a primary key: http://stackoverflow.com/questions/42860579/return-none-as-classtype-in-entity-inheritance-on-ponyorm --- pony/orm/core.py | 11 ++++++++--- pony/orm/tests/test_inheritance.py | 24 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 46ec8208e..85e47d295 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3881,19 +3881,24 @@ def _set_rbits(entity, objects, attrs): obj._rbits_ |= rbits & ~wbits def _parse_row_(entity, row, attr_offsets): discr_attr = entity._discriminator_attr_ - if not discr_attr: real_entity_subclass = entity + if not discr_attr: + discr_value = None + real_entity_subclass = entity else: discr_offset = attr_offsets[discr_attr][0] discr_value = discr_attr.validate(row[discr_offset], None, entity, from_db=True) real_entity_subclass = discr_attr.code2cls[discr_value] + discr_value = real_entity_subclass._discriminator_ # To convert unicode to str in Python 2.x avdict = {} for attr in real_entity_subclass._attrs_: offsets = attr_offsets.get(attr) if offsets is None or attr.is_discriminator: continue avdict[attr] = attr.parse_value(row, offsets) - if not entity._pk_is_composite_: pkval = avdict.pop(entity._pk_attrs_[0], None) - else: pkval = tuple(avdict.pop(attr, None) for attr in entity._pk_attrs_) + + pkval = tuple(avdict.pop(attr, discr_value) for attr in entity._pk_attrs_) + assert None not in pkval + if not entity._pk_is_composite_: pkval = pkval[0] return real_entity_subclass, pkval, avdict def _load_many_(entity, objects): database = entity._database_ diff --git a/pony/orm/tests/test_inheritance.py b/pony/orm/tests/test_inheritance.py index 4b4173528..6538d4abf 100644 --- 
a/pony/orm/tests/test_inheritance.py +++ b/pony/orm/tests/test_inheritance.py @@ -257,5 +257,29 @@ class Entity3(Entity1): result = select(e for e in Entity1 if e.b == 30 or e.c == 50) self.assertEqual([ e.id for e in result ], [ 2, 3 ]) + def test_discriminator_1(self): + db = Database('sqlite', ':memory:') + class Entity1(db.Entity): + a = Discriminator(str) + b = Required(int) + PrimaryKey(a, b) + class Entity2(db.Entity1): + c = Required(int) + db.generate_mapping(create_tables=True) + with db_session: + x = Entity1(b=10) + y = Entity2(b=20, c=30) + with db_session: + obj = Entity1.get(b=20) + self.assertEqual(obj.a, 'Entity2') + self.assertEqual(obj.b, 20) + self.assertEqual(obj._pkval_, ('Entity2', 20)) + with db_session: + obj = Entity1['Entity2', 20] + self.assertEqual(obj.a, 'Entity2') + self.assertEqual(obj.b, 20) + self.assertEqual(obj._pkval_, ('Entity2', 20)) + + if __name__ == '__main__': unittest.main() From d4f10371cce5999c5817da01a58ea6019c245f32 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 21 Mar 2017 18:13:59 +0300 Subject: [PATCH 129/547] Remove uninformative comments --- pony/orm/tests/test_json.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 52bc90f7c..01467ae20 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -278,8 +278,6 @@ def test_dict_set_item(self): p = get(p for p in self.Product) self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) - # JSON length - @db_session def test_len(self): with raises_if(self, self.db.provider.dialect == 'Oracle', @@ -289,8 +287,6 @@ def test_len(self): val = select(len(p.info['colors']) for p in self.Product).first() self.assertEqual(val, 3) - # # Json equality - @db_session def test_equal_str(self): p = get(p for p in self.Product if p.info['name'] == 'Apple iPad Air 2') From 10e45c5e91c7b03b01b6f15dcef16cde056e5dd0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 
21 Mar 2017 18:15:13 +0300 Subject: [PATCH 130/547] Fixes #221: issue with unicode json path keys --- pony/orm/sqlbuilding.py | 2 +- pony/orm/tests/test_json.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index bd1913bde..20205329b 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -555,7 +555,7 @@ def eval_json_path(cls, values): empty_slice = slice(None, None, None) for value in values: if isinstance(value, int): append('[%d]' % value) - elif isinstance(value, str): + elif isinstance(value, basestring): append('.' + value if is_ident(value) else '."%s"' % value.replace('"', '\\"')) elif value is Ellipsis: append('.*') elif value == empty_slice: append('[*]') diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 01467ae20..2858a1056 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -292,6 +292,11 @@ def test_equal_str(self): p = get(p for p in self.Product if p.info['name'] == 'Apple iPad Air 2') self.assertTrue(p) + @db_session + def test_unicode_key(self): + p = get(p for p in self.Product if p.info[u'name'] == 'Apple iPad Air 2') + self.assertTrue(p) + @db_session def test_equal_string_attr(self): p = get(p for p in self.Product if p.info['name'] == p.name) From 886437751d6eb7c077eaab9db67938f4988e465e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 23 Mar 2017 18:09:45 +0300 Subject: [PATCH 131/547] Improve BlobConverter.sql2py() --- pony/orm/dbapiprovider.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index d3e0a7128..8d7ed5122 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -619,7 +619,9 @@ def validate(converter, val): if isinstance(val, str): return buffer(val) throw(TypeError, "Attribute %r: expected type is 'buffer'. 
Got: %r" % (converter.attr, type(val))) def sql2py(converter, val): - if not isinstance(val, buffer): val = buffer(val) + if not isinstance(val, buffer): + try: val = buffer(val) + except: pass return val def sql_type(converter): return 'BLOB' From a941dd1fb845d79ae00e38b158da90d022ce4461 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 25 Mar 2017 01:07:22 +0300 Subject: [PATCH 132/547] Improved @raises_exception() decorator can test fragments of exception message text --- pony/orm/tests/testutils.py | 44 ++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/pony/orm/tests/testutils.py b/pony/orm/tests/testutils.py index 62cefbdc9..890cddf8d 100644 --- a/pony/orm/tests/testutils.py +++ b/pony/orm/tests/testutils.py @@ -7,38 +7,42 @@ from pony.orm.core import Database from pony.utils import import_module -def raises_exception(exc_class, msg=None): +def test_exception_msg(test_case, exc_msg, test_msg=None): + if test_msg is None: return + error_template = "incorrect exception message. 
expected '%s', got '%s'" + assert test_msg not in ('...', '....', '.....', '......') + if test_msg.startswith('...'): + if test_msg.endswith('...'): + test_case.assertIn(test_msg[3:-3], exc_msg, error_template % (test_msg, exc_msg)) + else: + test_case.assertTrue(exc_msg.endswith(test_msg[3:]), error_template % (test_msg, exc_msg)) + elif test_msg.endswith('...'): + test_case.assertTrue(exc_msg.startswith(test_msg[:-3]), error_template % (test_msg, exc_msg)) + else: + test_case.assertEqual(exc_msg, test_msg, error_template % (test_msg, exc_msg)) + +def raises_exception(exc_class, test_msg=None): def decorator(func): - def wrapper(self, *args, **kwargs): + def wrapper(test_case, *args, **kwargs): try: - func(self, *args, **kwargs) - self.fail("expected exception %s wasn't raised" % exc_class.__name__) + func(test_case, *args, **kwargs) + test_case.fail("Expected exception %s wasn't raised" % exc_class.__name__) except exc_class as e: - if not e.args: self.assertEqual(msg, None) - elif msg is not None: - self.assertEqual(e.args[0], msg, "incorrect exception message. 
expected '%s', got '%s'" % (msg, e.args[0])) + if not e.args: test_case.assertEqual(test_msg, None) + else: test_exception_msg(test_case, str(e), test_msg) wrapper.__name__ = func.__name__ return wrapper return decorator @contextmanager -def raises_if(test, cond, exc_class, exc_msg=None): +def raises_if(test_case, cond, exc_class, test_msg=None): try: yield except exc_class as e: - test.assertTrue(cond) - if exc_msg is None: pass - elif exc_msg.startswith('...') and exc_msg != '...': - if exc_msg.endswith('...'): - test.assertIn(exc_msg[3:-3], str(e)) - else: - test.assertTrue(str(e).endswith(exc_msg[3:])) - elif exc_msg.endswith('...'): - test.assertTrue(str(e).startswith(exc_msg[:-3])) - else: - test.assertEqual(str(e), exc_msg) + test_case.assertTrue(cond) + test_exception_msg(test_case, str(e), test_msg) else: - test.assertFalse(cond) + test_case.assertFalse(cond, "Expected exception %s wasn't raised" % exc_class.__name__) def flatten(x): result = [] From 9e2f7747ff7bbdfbe6b852df6aa1a80b9fd739cb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 25 Mar 2017 01:07:45 +0300 Subject: [PATCH 133/547] Fixes #238, fixes #133: raise TransactionIntegrityError exception instead of AssertionError if obj.collection.create(**kwargs) creates duplicate object --- pony/orm/core.py | 1 - pony/orm/tests/test_collections.py | 45 ++++++++++++++++++++++++++++++ 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 85e47d295..7f226fad0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3062,7 +3062,6 @@ def create(wrapper, **kwargs): kwargs[reverse.name] = wrapper._obj_ item_type = attr.py_type item = item_type(**kwargs) - wrapper.add(item) return item @cut_traceback def add(wrapper, new_items): diff --git a/pony/orm/tests/test_collections.py b/pony/orm/tests/test_collections.py index 57673edec..15242f148 100644 --- a/pony/orm/tests/test_collections.py +++ b/pony/orm/tests/test_collections.py @@ -1,4 +1,5 @@ from 
__future__ import absolute_import, print_function, division +from pony.py23compat import PY2 import unittest @@ -23,6 +24,50 @@ def test_setwrapper_nonzero(self): def test_get_by_collection_error(self): Group.get(students=[]) + @db_session + def test_collection_create_one2many_1(self): + g = Group['3132'] + g.students.create(record=106, name='Mike', scholarship=200) + flush() + self.assertEqual(len(g.students), 3) + rollback() + + @raises_exception(TypeError, "When using Group.students.create(), " + "'group' attribute should not be passed explicitly") + @db_session + def test_collection_create_one2many_2(self): + g = Group['3132'] + g.students.create(record=106, name='Mike', scholarship=200, group=g) + + @raises_exception(TransactionIntegrityError, "Object Student[105] cannot be stored in the database...") + @db_session + def test_collection_create_one2many_3(self): + g = Group['3132'] + g.students.create(record=105, name='Mike', scholarship=200) + + @db_session + def test_collection_create_many2many_1(self): + g = Group['3132'] + g.subjects.create(name='Biology') + flush() + self.assertEqual(len(g.subjects), 3) + rollback() + + @raises_exception(TypeError, "When using Group.subjects.create(), " + "'groups' attribute should not be passed explicitly") + @db_session + def test_collection_create_many2many_2(self): + g = Group['3132'] + g.subjects.create(name='Biology', groups=[g]) + + @raises_exception(TransactionIntegrityError, + "Object Subject[u'Math'] cannot be stored in the database..." 
if PY2 else + "Object Subject['Math'] cannot be stored in the database...") + @db_session + def test_collection_create_many2many_3(self): + g = Group['3132'] + g.subjects.create(name='Math') + # replace collection items when the old ones are not fully loaded ##>>> from pony.examples.orm.students01.model import * ##>>> s1 = Student[101] From 9e65691ad2ec62e868287ec5a4de39e00aaf5c7a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 31 May 2017 13:56:30 +0300 Subject: [PATCH 134/547] Fix __all__ declaration for better compatibility with PyCharm --- pony/orm/core.py | 54 +++++++++++++++++++++++------------------------- 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7f226fad0..3ae27afab 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -29,45 +29,43 @@ from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \ deprecated, import_module, parse_expr, is_ident, tostring, strjoin, concat -__all__ = ''' - pony +__all__ = [ + 'pony', - DBException RowNotFound MultipleRowsFound TooManyRowsFound + 'DBException', 'RowNotFound', 'MultipleRowsFound', 'TooManyRowsFound', - Warning Error InterfaceError DatabaseError DataError OperationalError - IntegrityError InternalError ProgrammingError NotSupportedError + 'Warning', 'Error', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', + 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - OrmError ERDiagramError DBSchemaError MappingError - TableDoesNotExist TableIsNotEmpty ConstraintError CacheIndexError PermissionError - ObjectNotFound MultipleObjectsFoundError TooManyObjectsFoundError OperationWithDeletedObjectError - TransactionError ConnectionClosedError TransactionIntegrityError IsolationError CommitException RollbackException - UnrepeatableReadError OptimisticCheckError UnresolvableCyclicDependency UnexpectedError DatabaseSessionIsOver - 
DatabaseContainsIncorrectValue DatabaseContainsIncorrectEmptyValue + 'OrmError', 'ERDiagramError', 'DBSchemaError', 'MappingError', + 'TableDoesNotExist', 'TableIsNotEmpty', 'ConstraintError', 'CacheIndexError', + 'ObjectNotFound', 'MultipleObjectsFoundError', 'TooManyObjectsFoundError', 'OperationWithDeletedObjectError', + 'TransactionError', 'ConnectionClosedError', 'TransactionIntegrityError', 'IsolationError', + 'CommitException', 'RollbackException', 'UnrepeatableReadError', 'OptimisticCheckError', + 'UnresolvableCyclicDependency', 'UnexpectedError', 'DatabaseSessionIsOver', + 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', + 'TranslationError', 'ExprEvalError', 'PermissionError', - TranslationError ExprEvalError + 'Database', 'sql_debug', 'show', - RowNotFound MultipleRowsFound TooManyRowsFound + 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator', + 'composite_key', 'composite_index', + 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', - Database sql_debug show + 'LongStr', 'LongUnicode', 'Json', - PrimaryKey Required Optional Set Discriminator - composite_key composite_index - flush commit rollback db_session with_transaction + 'select', 'left_join', 'get', 'exists', 'delete', - LongStr LongUnicode Json + 'count', 'sum', 'min', 'max', 'avg', 'distinct', - select left_join get exists delete + 'JOIN', 'desc', 'concat', 'raw_sql', - count sum min max avg distinct + 'buffer', 'unicode', - JOIN desc concat raw_sql - - buffer unicode - - get_current_user set_current_user perm has_perm - get_user_groups get_user_roles get_object_labels - user_groups_getter user_roles_getter obj_labels_getter - '''.split() + 'get_current_user', 'set_current_user', 'perm', 'has_perm', + 'get_user_groups', 'get_user_roles', 'get_object_labels', + 'user_groups_getter', 'user_roles_getter', 'obj_labels_getter' +] debug = False suppress_debug_change = False From 4c65cbb839ece6291c4262d1ae7bb8641389a1a3 Mon Sep 17 00:00:00 2001 From: 
Alexander Kozlovsky Date: Thu, 11 May 2017 18:24:22 +0300 Subject: [PATCH 135/547] Fix SQLite double lock release problem --- pony/orm/dbproviders/sqlite.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index adcf09813..0b44fbbc9 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -295,21 +295,30 @@ def set_transaction_mode(provider, connection, cache): def commit(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction - DBAPIProvider.commit(provider, connection, cache) - if in_transaction: - provider.transaction_lock.release() + try: + DBAPIProvider.commit(provider, connection, cache) + finally: + if in_transaction: + cache.in_transaction = False + provider.transaction_lock.release() def rollback(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction - DBAPIProvider.rollback(provider, connection, cache) - if in_transaction: - provider.transaction_lock.release() + try: + DBAPIProvider.rollback(provider, connection, cache) + finally: + if in_transaction: + cache.in_transaction = False + provider.transaction_lock.release() def drop(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction - DBAPIProvider.drop(provider, connection, cache) - if in_transaction: - provider.transaction_lock.release() + try: + DBAPIProvider.drop(provider, connection, cache) + finally: + if in_transaction: + cache.in_transaction = False + provider.transaction_lock.release() @wrap_dbapi_exceptions def release(provider, connection, cache=None): From a78666b6be2b654b8f3f9e702851922c20f143a9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 10 Mar 2017 18:30:32 +0300 Subject: [PATCH 136/547] Bug in obj._db_set_() method fixed --- pony/orm/core.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git 
a/pony/orm/core.py b/pony/orm/core.py index 3ae27afab..d1653a265 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4420,8 +4420,10 @@ def _db_set_(obj, avdict, unpickling=False): cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) for attr, new_val in iteritems(avdict): - converter = attr.converters[0] - new_val = converter.dbval2val(new_val, obj) + if not attr.reverse: + assert len(attr.converters) == 1, attr + converter = attr.converters[0] + new_val = converter.dbval2val(new_val, obj) obj._vals_[attr] = new_val def _delete_(obj, undo_funcs=None): status = obj._status_ From 7cb96e6aad8ee2f703fc8a7cc9cf7d7b0503fea1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 28 Oct 2016 13:18:06 +0300 Subject: [PATCH 137/547] Remove unused local variable --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index d1653a265..981184b92 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -556,7 +556,7 @@ def _bind(self, *args, **kwargs): 'Pony no longer supports PyGreSQL module. Please use psycopg2 instead.') provider_module = import_module('pony.orm.dbproviders.' 
+ provider) provider_cls = provider_module.provider_cls - self.provider = provider = provider_cls(*args, **kwargs) + self.provider = provider_cls(*args, **kwargs) @property def last_sql(database): return database._dblocal.last_sql From f8550534491001c03edab015eed8a156256f21df Mon Sep 17 00:00:00 2001 From: pwtail Date: Fri, 30 Jun 2017 18:46:46 +0300 Subject: [PATCH 138/547] Add a possibility to specify max_len as a keyword argument --- pony/orm/dbapiprovider.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 8d7ed5122..674630b2a 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -399,9 +399,12 @@ def __init__(converter, provider, py_type, attr=None): Converter.__init__(converter, provider, py_type, attr) def init(converter, kwargs): attr = converter.attr - if not attr.args: max_len = None - elif len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) - else: max_len = attr.args[0] + max_len = attr.kwargs.pop('max_len', None) + if len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) + elif attr.args: + if max_len is not None: throw(TypeError, + 'Max length option specified twice: as a positional argument and as a `max_len` named argument') + max_len = attr.args[0] if issubclass(attr.py_type, (LongStr, LongUnicode)): if max_len is not None: throw(TypeError, 'Max length is not supported for CLOBs') elif max_len is None: max_len = converter.provider.varchar_default_max_len From 2d3afb24bc4ff1ffdd09d841fc07637c882d4661 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 17 Jul 2017 13:10:32 +0300 Subject: [PATCH 139/547] More aggressive cache clearing --- pony/orm/core.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 981184b92..54e17675b 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1554,6 +1554,8 @@ def commit(cache): assert 
cache.connection is not None cache.database.provider.commit(cache.connection, cache) cache.for_update.clear() + cache.query_results.clear() + cache.max_id_cache.clear() cache.immediate = True except: cache.rollback() @@ -1572,21 +1574,24 @@ def close(cache, rollback=True): connection = cache.connection if connection is None: return cache.connection = None - if rollback: - try: provider.rollback(connection, cache) - except: - provider.drop(connection, cache) - raise - provider.release(connection, cache) - db_session = cache.db_session or local.db_session - if db_session and db_session.strict: - cache.clear() - def clear(cache): - for obj in cache.objects: - obj._vals_ = obj._dbvals_ = obj._session_cache_ = None - cache.objects = cache.indexes = cache.seeds = cache.for_update = cache.modified_collections \ - = cache.objects_to_save = cache.saved_objects = cache.query_results \ - = cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + + try: + if rollback: + try: provider.rollback(connection, cache) + except: + provider.drop(connection, cache) + raise + provider.release(connection, cache) + finally: + cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ + = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ + = cache.modified_collections = cache.collection_statistics = None + + db_session = cache.db_session or local.db_session + if db_session and db_session.strict: + for obj in cache.objects: + obj._vals_ = obj._dbvals_ = obj._session_cache_ = None + cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None @contextmanager def flush_disabled(cache): cache.noflush_counter += 1 From 6495daad813805e08ee73db065b693755713d5ae Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 17 Jul 2017 13:11:52 +0300 Subject: [PATCH 140/547] Fixes #276: Memory leak in get_lambda_args() --- pony/orm/decompiling.py | 7 ++----- pony/utils/utils.py | 23 ++++++++++++++++++++--- 2 files changed, 
22 insertions(+), 8 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 8fa9ae700..8a580a15a 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -7,15 +7,13 @@ from pony.thirdparty.compiler import ast, parse -from pony.utils import throw +from pony.utils import throw, get_codeobject_id ##ast.And.__repr__ = lambda self: "And(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),) ##ast.Or.__repr__ = lambda self: "Or(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),) ast_cache = {} -codeobjects = {} - def decompile(x): cells = {} t = type(x) @@ -28,10 +26,9 @@ def decompile(x): else: if x.__closure__: cells = dict(izip(codeobject.co_freevars, x.__closure__)) else: throw(TypeError) - key = id(codeobject) + key = get_codeobject_id(codeobject) result = ast_cache.get(key) if result is None: - codeobjects[key] = codeobject decompiler = Decompiler(codeobject) result = decompiler.ast, decompiler.external_names ast_cache[key] = result diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 03ef1e040..7182c6854 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -126,11 +126,27 @@ def truncate_repr(s, max_len=100): s = repr(s) return s if len(s) <= max_len else s[:max_len-3] + '...' 
+codeobjects = {} + +def get_codeobject_id(codeobject): + codeobject_id = id(codeobject) + if codeobject_id not in codeobjects: + codeobjects[codeobject_id] = codeobject + return codeobject_id + lambda_args_cache = {} def get_lambda_args(func): - names = lambda_args_cache.get(func) + if type(func) is types.FunctionType: + codeobject = func.func_code if PY2 else func.__code__ + cache_key = get_codeobject_id(codeobject) + elif isinstance(func, ast.Lambda): + cache_key = func + else: assert False # pragma: no cover + + names = lambda_args_cache.get(cache_key) if names is not None: return names + if type(func) is types.FunctionType: if hasattr(inspect, 'signature'): names, argsname, kwname, defaults = [], None, None, None @@ -162,7 +178,8 @@ def get_lambda_args(func): if argsname: throw(TypeError, '*%s is not supported' % argsname) if kwname: throw(TypeError, '**%s is not supported' % kwname) if defaults: throw(TypeError, 'Defaults are not supported') - lambda_args_cache[func] = names + + lambda_args_cache[cache_key] = names return names _cache = {} @@ -493,4 +510,4 @@ def concat(*args): return ''.join(tostring(arg) for arg in args) def is_utf8(encoding): - return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') + return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') From df8f8dd97c15bc24fd0978c045dcc4bcd5994a43 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Mon, 17 Jul 2017 17:45:08 -0400 Subject: [PATCH 141/547] Update changelog and version for 0.7.2 release --- CHANGELOG.md | 19 +++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f2619b064..bc3375323 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,22 @@ +# Pony ORM Release 0.7.2 (2017-07-17) + +## New features + +* All arguments of db.bind() can be specified as keyword arguments. 
Previously Pony required the first positional argument which specified the database provider. Now you can pass all the database parameters using the dict: db.bind(**db_params). See https://docs.ponyorm.com/api_reference.html#Database.bind +* The `optimistic` attribute option is added https://docs.ponyorm.com/api_reference.html#cmdoption-arg-optimistic + +## Bugfixes + +* Fixes #219: when a database driver raises an error, sometimes this error was masked by the 'RollbackException: InterfaceError: connection already closed' exception. This happened because on error, Pony tried to rollback transaction, but the connection to the database was already closed and it masked the initial error. Now Pony displays the original error which helps to understand the cause of the problem. +* Fixes #276: Memory leak +* Fixes the __all__ declaration. Previously IDEs, such as PyCharm, could not understand what is going to be imported by 'from pony.orm import *'. Now it works fine. +* Fixes #232: negate check for numeric expressions now checks if value is zero or NULL +* Fixes #238, fixes #133: raise TransactionIntegrityError exception instead of AssertionError if obj.collection.create(**kwargs) creates a duplicate object +* Fixes #221: issue with unicode json path keys +* Fixes bug when discriminator column is used as a part of a primary key +* Handle situation when SQLite blob column contains non-binary value + + # Pony ORM Release 0.7.1 (2017-01-10) ## New features diff --git a/pony/__init__.py b/pony/__init__.py index c6307af36..340bee89d 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.2-dev' +__version__ = '0.7.2' uid = str(random.randint(1, 1000000)) From d34147642795e4a03d1b3992b3267522a5b391b8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 18 Jul 2017 00:53:01 +0300 Subject: [PATCH 142/547] Change Pony version: 0.7.2 -> 0.7.3-dev --- pony/__init__.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 340bee89d..99ca019c1 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.2' +__version__ = '0.7.3-dev' uid = str(random.randint(1, 1000000)) From 8c9ee893e4a69ba84ded90576182629b601a7d40 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Aug 2017 14:38:08 +0300 Subject: [PATCH 143/547] Fix db_session(strict=True) which was broken in 2d3afb24 --- pony/orm/core.py | 7 +++---- pony/orm/tests/test_db_session.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 54e17675b..2535d203b 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1583,15 +1583,14 @@ def close(cache, rollback=True): raise provider.release(connection, cache) finally: - cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ - = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ - = cache.modified_collections = cache.collection_statistics = None - db_session = cache.db_session or local.db_session if db_session and db_session.strict: for obj in cache.objects: obj._vals_ = obj._dbvals_ = obj._session_cache_ = None cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ + = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ + = cache.modified_collections = cache.collection_statistics = None @contextmanager def flush_disabled(cache): cache.noflush_counter += 1 diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 7bf2c1160..8145d7c3f 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -435,5 +435,15 @@ def test12(self): s1 = Student[1] s1.set(name='New name') + def test_db_session_strict_1(self): 
+ with db_session(strict=True): + s1 = Student[1] + + @raises_exception(DatabaseSessionIsOver, 'Cannot read value of Student[1].name: the database session is over') + def test_db_session_strict_2(self): + with db_session(strict=True): + s1 = Student[1] + name = s1.name + if __name__ == '__main__': unittest.main() From 69f82bb2dfa65b738fbb17a4a7bcc0fc8226bfbc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Aug 2017 15:29:48 +0300 Subject: [PATCH 144/547] Remove unnecessary code --- pony/orm/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2535d203b..c77aaedd1 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2100,7 +2100,6 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): old_val = obj._vals_.get(attr, NOT_LOADED) assert old_val == old_dbval, (old_val, old_dbval) if attr.is_part_of_unique_index: - cache = obj._session_cache_ if attr.is_unique: cache.db_update_simple_index(obj, attr, old_val, new_dbval) get_val = obj._vals_.get for attrs, i in attr.composite_keys: @@ -4815,8 +4814,6 @@ def _save_deleted_(obj): obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) def _save_(obj, dependent_objects=None): - cache = obj._session_cache_ - assert cache.is_alive status = obj._status_ if status in ('created', 'modified'): @@ -4829,6 +4826,7 @@ def _save_(obj, dependent_objects=None): assert obj._status_ in saved_statuses cache = obj._session_cache_ + assert cache.is_alive cache.saved_objects.append((obj, obj._status_)) objects_to_save = cache.objects_to_save save_pos = obj._save_pos_ From 295840e8eb28a31d62fb15d0e60a30f92c311e29 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Aug 2017 15:47:18 +0300 Subject: [PATCH 145/547] Collection.count() method should check if session is alive --- pony/orm/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index c77aaedd1..e06c255ef 100644 --- 
a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2980,6 +2980,7 @@ def count(wrapper): setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.count is not None: return setdata.count + if not cache.is_alive: throw_db_session_is_over(obj, attr) entity = attr.entity reverse = attr.reverse database = entity._database_ From 60d85aab65d2076ea4bf94f35d508035c36c1ca1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Aug 2017 16:16:04 +0300 Subject: [PATCH 146/547] Set obj._session_cache_ to None after exiting from db_session for better garbage collection --- pony/orm/core.py | 71 ++++++++++++++++------------ pony/orm/tests/test_core_multiset.py | 2 +- 2 files changed, 41 insertions(+), 32 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e06c255ef..dba384cf8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1017,7 +1017,8 @@ def user_has_no_rights_to_see(obj, attr=None): caches = set() def obj_converter(obj): if not isinstance(obj, Entity): return converter(obj) - caches.add(obj._session_cache_) + cache = obj._session_cache_ + if cache is not None: caches.add(cache) if len(caches) > 1: throw(TransactionError, 'An attempt to serialize objects belonging to different transactions') if not can_view(user, obj): @@ -1584,10 +1585,15 @@ def close(cache, rollback=True): provider.release(connection, cache) finally: db_session = cache.db_session or local.db_session - if db_session and db_session.strict: - for obj in cache.objects: - obj._vals_ = obj._dbvals_ = obj._session_cache_ = None - cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + if db_session: + if db_session.strict: + for obj in cache.objects: + obj._vals_ = obj._dbvals_ = obj._session_cache_ = None + cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + else: + for obj in cache.objects: + obj._dbvals_ = obj._session_cache_ = None + cache.objects = cache.objects_to_save = cache.saved_objects 
= cache.query_results \ = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ = cache.modified_collections = cache.collection_statistics = None @@ -1935,7 +1941,8 @@ def parse_value(attr, row, offsets): else: val = attr.py_type._get_by_raw_pkval_(vals) return val def load(attr, obj): - if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, + cache = obj._session_cache_ + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot load attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) if not attr.columns: reverse = attr.reverse @@ -1989,7 +1996,7 @@ def get(attr, obj): @cut_traceback def __set__(attr, obj, new_val, undo_funcs=None): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot assign new value to attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) if obj._status_ in del_statuses: throw_object_was_deleted(obj) reverse = attr.reverse @@ -2073,7 +2080,7 @@ def undo_func(): raise def db_set(attr, obj, new_dbval, is_reverse_call=False): cache = obj._session_cache_ - assert cache.is_alive + assert cache is not None and cache.is_alive assert obj._status_ not in created_or_deleted_statuses assert attr.pk_offset is None if new_dbval is NOT_LOADED: assert is_reverse_call @@ -2503,7 +2510,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): return items def load(attr, obj, items=None): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot load collection %s.%s: the database session is over' % (safe_repr(obj), attr.name)) assert obj._status_ not in del_statuses setdata = obj._vals_.get(attr) @@ -2665,7 +2672,7 @@ def __set__(attr, obj, new_items, undo_funcs=None): if isinstance(new_items, SetInstance) and new_items._obj_ is obj and 
new_items._attr_ is attr: return # after += or -= cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): @@ -2901,7 +2908,8 @@ def __repr__(wrapper): return '<%s %r.%s>' % (wrapper.__class__.__name__, wrapper._obj_, wrapper._attr_.name) @cut_traceback def __str__(wrapper): - if not wrapper._obj_._session_cache_.is_alive: content = '...' + cache = wrapper._obj_._session_cache_ + if cache is None or not cache.is_alive: content = '...' else: content = ', '.join(imap(str, wrapper)) return '%s([%s])' % (wrapper.__class__.__name__, content) @cut_traceback @@ -2980,7 +2988,7 @@ def count(wrapper): setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.count is not None: return setdata.count - if not cache.is_alive: throw_db_session_is_over(obj, attr) + if cache is None or not cache.is_alive: throw_db_session_is_over(obj, attr) entity = attr.entity reverse = attr.reverse database = entity._database_ @@ -3070,7 +3078,7 @@ def add(wrapper, new_items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): @@ -3110,7 +3118,7 @@ def remove(wrapper, items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) if 
obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): @@ -3153,7 +3161,8 @@ def __isub__(wrapper, items): def clear(wrapper): obj = wrapper._obj_ attr = wrapper._attr_ - if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, + cache = obj._session_cache_ + if cache is None or not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) if obj._status_ in del_statuses: throw_object_was_deleted(obj) attr.__set__(obj, ()) @@ -3210,7 +3219,8 @@ def distinct(multiset): return multiset._items_.copy() @cut_traceback def __repr__(multiset): - if multiset._obj_._session_cache_.is_alive: + cache = multiset._obj_._session_cache_ + if cache is not None and cache.is_alive: size = builtins.sum(itervalues(multiset._items_)) if size == 1: size_str = ' (1 item)' else: size_str = ' (%d items)' % size @@ -4272,7 +4282,7 @@ def __repr__(obj): return '%s[%s]' % (obj.__class__.__name__, pkval) def _load_(obj): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot load object %s: the database session is over' % safe_repr(obj)) entity = obj.__class__ database = entity._database_ @@ -4294,7 +4304,7 @@ def _load_(obj): @cut_traceback def load(obj, *attrs): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot load object %s: the database session is over' % safe_repr(obj)) entity = obj.__class__ database = entity._database_ @@ -4355,12 +4365,9 @@ def load(obj, *attrs): 'Phantom object %s disappeared' % safe_repr(obj)) def _attr_changed_(obj, attr): cache = obj._session_cache_ - if not cache.is_alive: throw( - DatabaseSessionIsOver, - 'Cannot assign new value to attribute %s.%s: the database session' - ' is over' % (safe_repr(obj), attr.name)) - if 
obj._status_ in del_statuses: - throw_object_was_deleted(obj) + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, + 'Cannot assign new value to attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if obj._status_ in del_statuses: throw_object_was_deleted(obj) status = obj._status_ wbits = obj._wbits_ bit = obj._bits_[attr] @@ -4377,7 +4384,7 @@ def _attr_changed_(obj, attr): def _db_set_(obj, avdict, unpickling=False): assert obj._status_ not in created_or_deleted_statuses cache = obj._session_cache_ - assert cache.is_alive + assert cache is not None and cache.is_alive cache.seeds[obj._pk_attrs_].discard(obj) if not avdict: return @@ -4435,6 +4442,7 @@ def _delete_(obj, undo_funcs=None): is_recursive_call = undo_funcs is not None if not is_recursive_call: undo_funcs = [] cache = obj._session_cache_ + assert cache is not None and cache.is_alive with cache.flush_disabled(): get_val = obj._vals_.get undo_list = [] @@ -4528,13 +4536,14 @@ def undo_func(): raise @cut_traceback def delete(obj): - if not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, + cache = obj._session_cache_ + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot delete object %s: the database session is over' % safe_repr(obj)) obj._delete_() @cut_traceback def set(obj, **kwargs): cache = obj._session_cache_ - if not cache.is_alive: throw(DatabaseSessionIsOver, + if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, 'Cannot change object %s: the database session is over' % safe_repr(obj)) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): @@ -4816,7 +4825,6 @@ def _save_deleted_(obj): cache.indexes[obj._pk_attrs_].pop(obj._pkval_) def _save_(obj, dependent_objects=None): status = obj._status_ - if status in ('created', 'modified'): obj._save_principal_objects_(dependent_objects) @@ -4827,7 +4835,7 @@ def _save_(obj, dependent_objects=None): assert obj._status_ in 
saved_statuses cache = obj._session_cache_ - assert cache.is_alive + assert cache is not None and cache.is_alive cache.saved_objects.append((obj, obj._status_)) objects_to_save = cache.objects_to_save save_pos = obj._save_pos_ @@ -4842,7 +4850,7 @@ def flush(obj): assert obj._save_pos_ is not None, 'save_pos is None for %s object' % obj._status_ cache = obj._session_cache_ - assert not cache.saved_objects + assert cache is not None and cache.is_alive and not cache.saved_objects with cache.flush_disabled(): obj._before_save_() # should be inside flush_disabled to prevent infinite recursion # TODO: add to documentation that flush is disabled inside before_xxx hooks @@ -4871,7 +4879,8 @@ def after_delete(obj): pass @cut_traceback def to_dict(obj, only=None, exclude=None, with_collections=False, with_lazy=False, related_objects=False): - if obj._session_cache_.modified: obj._session_cache_.flush() + cache = obj._session_cache_ + if cache is not None and cache.is_alive and cache.modified: cache.flush() attrs = obj.__class__._get_attrs_(only, exclude, with_collections, with_lazy) result = {} for attr in attrs: diff --git a/pony/orm/tests/test_core_multiset.py b/pony/orm/tests/test_core_multiset.py index d1374cfa4..278e33df9 100644 --- a/pony/orm/tests/test_core_multiset.py +++ b/pony/orm/tests/test_core_multiset.py @@ -77,7 +77,7 @@ def test_multiset_repr_4(self): with db_session: g = Group[101] multiset = g.students.courses - self.assertEqual(multiset._obj_._session_cache_.is_alive, False) + self.assertIsNone(multiset._obj_._session_cache_) self.assertEqual(repr(multiset), "") @db_session From 681fd24a94c8f0cf1a2a325939dff695524cf21b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Aug 2017 16:18:50 +0300 Subject: [PATCH 147/547] Unload collections which are not fully loaded after exiting from db session for better garbage collection --- pony/orm/core.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 
dba384cf8..b1d9f8bc0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1593,6 +1593,9 @@ def close(cache, rollback=True): else: for obj in cache.objects: obj._dbvals_ = obj._session_cache_ = None + for attr, setdata in iteritems(obj._vals_): + if attr.is_collection: + if not setdata.is_fully_loaded: obj._vals_[attr] = None cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ From 1cd89c8ce0d451f45db5c34c5aa43c50c140d043 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Aug 2017 17:22:19 +0300 Subject: [PATCH 148/547] Use throw_db_session_is_over() everywhere --- pony/orm/core.py | 58 ++++++++++++------------------- pony/orm/tests/test_db_session.py | 10 +++--- 2 files changed, 28 insertions(+), 40 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b1d9f8bc0..0a418b2da 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -495,9 +495,9 @@ def wrapped_interact(iterator, input=None, exc_info=None): db_session = DBSessionContextManager() -def throw_db_session_is_over(obj, attr): - throw(DatabaseSessionIsOver, 'Cannot read value of %s.%s: the database session is over' - % (safe_repr(obj), attr.name)) +def throw_db_session_is_over(action, obj, attr=None): + msg = 'Cannot %s %s%s: the database session is over' + throw(DatabaseSessionIsOver, msg % (action, safe_repr(obj), '.%s' % attr.name if attr else '')) def with_transaction(*args, **kwargs): deprecated(3, "@with_transaction decorator is deprecated, use @db_session decorator instead") @@ -1945,8 +1945,7 @@ def parse_value(attr, row, offsets): return val def load(attr, obj): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load attribute', obj, attr) if not attr.columns: 
reverse = attr.reverse assert reverse is not None and reverse.columns @@ -1990,7 +1989,7 @@ def get(attr, obj): if attr.pk_offset is None and obj._status_ in ('deleted', 'cancelled'): throw_object_was_deleted(obj) vals = obj._vals_ - if vals is None: throw_db_session_is_over(obj, attr) + if vals is None: throw_db_session_is_over('read value of', obj, attr) val = vals[attr] if attr in vals else attr.load(obj) if val is not None and attr.reverse and val._subclasses_ and val._status_ not in ('deleted', 'cancelled'): seeds = obj._session_cache_.seeds[val._pk_attrs_] @@ -1999,8 +1998,7 @@ def get(attr, obj): @cut_traceback def __set__(attr, obj, new_val, undo_funcs=None): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot assign new value to attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) reverse = attr.reverse new_val = attr.validate(new_val, obj, from_db=False) @@ -2513,8 +2511,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): return items def load(attr, obj, items=None): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load collection %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load collection', obj, attr) assert obj._status_ not in del_statuses setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() @@ -2651,7 +2648,7 @@ def construct_sql_m2m(attr, batch_size=1, items_count=0): return sql, adapter def copy(attr, obj): if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) 
setdata = obj._vals_.get(attr) if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) reverse = attr.reverse @@ -2675,8 +2672,7 @@ def __set__(attr, obj, new_items, undo_funcs=None): if isinstance(new_items, SetInstance) and new_items._obj_ is obj and new_items._attr_ is attr: return # after += or -= cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): new_items = attr.validate(new_items, obj) @@ -2920,7 +2916,7 @@ def __nonzero__(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = attr.load(obj) if setdata: return True @@ -2931,7 +2927,7 @@ def is_empty(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.is_fully_loaded: return not setdata @@ -2977,7 +2973,7 @@ def __len__(wrapper): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) return len(setdata) @@ -2987,11 
+2983,11 @@ def count(wrapper): obj = wrapper._obj_ cache = obj._session_cache_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) setdata = obj._vals_.get(attr) if setdata is None: setdata = obj._vals_[attr] = SetData() elif setdata.count is not None: return setdata.count - if cache is None or not cache.is_alive: throw_db_session_is_over(obj, attr) + if cache is None or not cache.is_alive: throw_db_session_is_over('read value of', obj, attr) entity = attr.entity reverse = attr.reverse database = entity._database_ @@ -3039,7 +3035,7 @@ def __contains__(wrapper, item): attr = wrapper._attr_ obj = wrapper._obj_ if obj._status_ in del_statuses: throw_object_was_deleted(obj) - if obj._vals_ is None: throw_db_session_is_over(obj, attr) + if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) if not isinstance(item, attr.py_type): return False reverse = attr.reverse @@ -3081,8 +3077,7 @@ def add(wrapper, new_items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): reverse = attr.reverse @@ -3121,8 +3116,7 @@ def remove(wrapper, items): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with 
cache.flush_disabled(): reverse = attr.reverse @@ -3165,8 +3159,7 @@ def clear(wrapper): obj = wrapper._obj_ attr = wrapper._attr_ cache = obj._session_cache_ - if cache is None or not obj._session_cache_.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change collection %s.%s: the database session is over' % (safe_repr(obj), attr)) + if cache is None or not obj._session_cache_.is_alive: throw_db_session_is_over('change collection', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) attr.__set__(obj, ()) @cut_traceback @@ -4285,8 +4278,7 @@ def __repr__(obj): return '%s[%s]' % (obj.__class__.__name__, pkval) def _load_(obj): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) entity = obj.__class__ database = entity._database_ if cache is not database._get_cache(): @@ -4307,8 +4299,7 @@ def _load_(obj): @cut_traceback def load(obj, *attrs): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot load object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) entity = obj.__class__ database = entity._database_ if cache is not database._get_cache(): @@ -4368,8 +4359,7 @@ def load(obj, *attrs): 'Phantom object %s disappeared' % safe_repr(obj)) def _attr_changed_(obj, attr): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot assign new value to attribute %s.%s: the database session is over' % (safe_repr(obj), attr.name)) + if cache is None or not cache.is_alive: throw_db_session_is_over('assign new value to', obj, attr) if obj._status_ in del_statuses: throw_object_was_deleted(obj) status = obj._status_ wbits = obj._wbits_ @@ -4540,14 +4530,12 @@ 
def undo_func(): @cut_traceback def delete(obj): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot delete object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('delete object', obj) obj._delete_() @cut_traceback def set(obj, **kwargs): cache = obj._session_cache_ - if cache is None or not cache.is_alive: throw(DatabaseSessionIsOver, - 'Cannot change object %s: the database session is over' % safe_repr(obj)) + if cache is None or not cache.is_alive: throw_db_session_is_over('change object', obj) if obj._status_ in del_statuses: throw_object_was_deleted(obj) with cache.flush_disabled(): avdict, collection_avdict = obj._keyargs_to_avdicts_(kwargs) diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 8145d7c3f..fc939501c 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -379,7 +379,7 @@ def test3(self): group_id = s1.group.id major = s1.group.major - @raises_exception(DatabaseSessionIsOver, 'Cannot assign new value to attribute Student[1].name: the database session is over') + @raises_exception(DatabaseSessionIsOver, 'Cannot assign new value to Student[1].name: the database session is over') def test4(self): with db_session: s1 = Student[1] @@ -396,28 +396,28 @@ def test6(self): g1 = Group[1] l = len(g1.students) - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test7(self): with db_session: s1 = Student[1] g1 = Group[1] g1.students.remove(s1) - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database 
session is over') def test8(self): with db_session: g2_students = Group[2].students g1 = Group[1] g1.students = g2_students - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test9(self): with db_session: s3 = Student[3] g1 = Group[1] g1.students.add(s3) - @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].Group.students: the database session is over') + @raises_exception(DatabaseSessionIsOver, 'Cannot change collection Group[1].students: the database session is over') def test10(self): with db_session: g1 = Group[1] From 7a220a306567a256d604d9bfaf7e13a9e5ff52b2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 4 Aug 2017 15:59:12 +0300 Subject: [PATCH 149/547] Allow db_session to accept `ddl` parameter when used as context manager --- pony/orm/core.py | 4 ++-- pony/orm/tests/test_db_session.py | 21 +++++++++++++++++++-- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 0a418b2da..b6efdd88d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -369,13 +369,13 @@ def __call__(db_session, *args, **kwargs): def __enter__(db_session): if db_session.retry is not 0: throw(TypeError, "@db_session can accept 'retry' parameter only when used as decorator and not as context manager") - if db_session.ddl: throw(TypeError, - "@db_session can accept 'ddl' parameter only when used as decorator and not as context manager") db_session._enter() def _enter(db_session): if local.db_session is None: assert not local.db_context_counter local.db_session = db_session + elif db_session.ddl and not local.db_session.ddl: throw(TransactionError, + 'Cannot start ddl transaction inside non-ddl transaction') elif db_session.serializable and not local.db_session.serializable: throw(TransactionError, 'Cannot 
start serializable transaction inside non-serializable transaction') local.db_context_counter += 1 diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index fc939501c..05e538208 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -276,12 +276,29 @@ def test_db_session_manager_4(self): else: self.fail() - @raises_exception(TypeError, "@db_session can accept 'ddl' parameter " - "only when used as decorator and not as context manager") + # restriction removed in 0.7.3: + # @raises_exception(TypeError, "@db_session can accept 'ddl' parameter " + # "only when used as decorator and not as context manager") def test_db_session_ddl_1(self): with db_session(ddl=True): pass + def test_db_session_ddl_1a(self): + with db_session(ddl=True): + with db_session(ddl=True): + pass + + def test_db_session_ddl_1b(self): + with db_session(ddl=True): + with db_session: + pass + + @raises_exception(TransactionError, 'Cannot start ddl transaction inside non-ddl transaction') + def test_db_session_ddl_1c(self): + with db_session: + with db_session(ddl=True): + pass + @raises_exception(TransactionError, "test() cannot be called inside of db_session") def test_db_session_ddl_2(self): @db_session(ddl=True) From cf4aede69c9613374c7b5cabec53cc9eea0bff4e Mon Sep 17 00:00:00 2001 From: pwtail Date: Fri, 4 Aug 2017 17:20:00 +0300 Subject: [PATCH 150/547] fix: max_len --- pony/orm/dbapiprovider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 674630b2a..30f6846f9 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -399,7 +399,7 @@ def __init__(converter, provider, py_type, attr=None): Converter.__init__(converter, provider, py_type, attr) def init(converter, kwargs): attr = converter.attr - max_len = attr.kwargs.pop('max_len', None) + max_len = kwargs.pop('max_len', None) if len(attr.args) > 1: unexpected_args(attr, attr.args[1:]) 
elif attr.args: if max_len is not None: throw(TypeError, From 5e3ebedde967794513e789ca5dcd9590b5061abc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 2 Aug 2017 13:10:24 +0300 Subject: [PATCH 151/547] Improve error message --- pony/orm/core.py | 14 +++++++++----- pony/orm/tests/test_relations_one2one2.py | 3 ++- pony/orm/tests/test_relations_symmetric_one2one.py | 3 ++- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b6efdd88d..d729a70ca 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2094,11 +2094,15 @@ def db_set(attr, obj, new_dbval, is_reverse_call=False): bit = obj._bits_except_volatile_[attr] if obj._rbits_ & bit: assert old_dbval is not NOT_LOADED - if new_dbval is NOT_LOADED: diff = '' - else: diff = ' (was: %s, now: %s)' % (old_dbval, new_dbval) - throw(UnrepeatableReadError, - 'Value of %s.%s for %s was updated outside of current transaction%s' - % (obj.__class__.__name__, attr.name, obj, diff)) + msg = 'Value of %s for %s was updated outside of current transaction' % (attr, obj) + if new_dbval is not NOT_LOADED: + msg = '%s (was: %s, now: %s)' % (msg, old_dbval, new_dbval) + elif isinstance(attr.reverse, Optional): + assert old_dbval is not None + msg = "Multiple %s objects linked with the same %s object. 
" \ + "Maybe %s attribute should be Set instead of Optional" \ + % (attr.entity.__name__, old_dbval, attr.reverse) + throw(UnrepeatableReadError, msg) if new_dbval is NOT_LOADED: obj._dbvals_.pop(attr, None) else: obj._dbvals_[attr] = new_dbval diff --git a/pony/orm/tests/test_relations_one2one2.py b/pony/orm/tests/test_relations_one2one2.py index 4add12bb8..c3f5f303b 100644 --- a/pony/orm/tests/test_relations_one2one2.py +++ b/pony/orm/tests/test_relations_one2one2.py @@ -121,7 +121,8 @@ def test_8(self): self.assertEqual([2, None, None], wives) husbands = db.select('husband from female order by female.id') self.assertEqual([None, 1, None], husbands) - @raises_exception(UnrepeatableReadError, 'Value of Male.wife for Male[1] was updated outside of current transaction') + @raises_exception(UnrepeatableReadError, 'Multiple Male objects linked with the same Female[1] object. ' + 'Maybe Female.husband attribute should be Set instead of Optional') def test_9(self): db.execute('update female set husband = 3 where id = 1') m1 = Male[1] diff --git a/pony/orm/tests/test_relations_symmetric_one2one.py b/pony/orm/tests/test_relations_symmetric_one2one.py index e07b920bd..47d1bad53 100644 --- a/pony/orm/tests/test_relations_symmetric_one2one.py +++ b/pony/orm/tests/test_relations_symmetric_one2one.py @@ -64,7 +64,8 @@ def test3(self): def test4(self): persons = set(select(p for p in Person if p.spouse.name in ('B', 'D'))) self.assertEqual(persons, {Person[1], Person[3]}) - @raises_exception(UnrepeatableReadError, 'Value of Person.spouse for Person[1] was updated outside of current transaction') + @raises_exception(UnrepeatableReadError, 'Multiple Person objects linked with the same Person[2] object. 
' + 'Maybe Person.spouse attribute should be Set instead of Optional') def test5(self): db.execute('update person set spouse = 3 where id = 2') p1 = Person[1] From 42262152f93961c6e38606f9040de3d774f2918b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 11:17:33 +0300 Subject: [PATCH 152/547] Remove duplicate code --- pony/orm/dbproviders/sqlite.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 0b44fbbc9..9c272dfc6 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -524,7 +524,6 @@ def _connect(pool): con.create_function('py_json_contains', 3, py_json_contains) con.create_function('py_json_nonzero', 2, py_json_nonzero) con.create_function('py_json_array_length', -1, py_json_array_length) - con.create_function('py_lower', 1, py_lower) if sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') def disconnect(pool): From 04fbaa2b6206721fe9fff70dd421077d3e8999ad Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 12:59:23 +0300 Subject: [PATCH 153/547] Remove @print_traceback decarator from pony.dbproviders.sqlite --- pony/orm/dbproviders/sqlite.py | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 9c272dfc6..d6672a984 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -210,20 +210,6 @@ def py2sql(converter, val): class SQLiteJsonConverter(dbapiprovider.JsonConverter): json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} -def print_traceback(func): - @wraps(func) - def wrapper(*args, **kw): - try: - return func(*args, **kw) - except: - if core.debug: - import traceback - msg = traceback.format_exc() - log_orm(msg) - raise - return wrapper - - class SQLiteProvider(DBAPIProvider): dialect = 'SQLite' max_name_len = 1024 @@ -414,7 +400,6 @@ def 
func(value): py_upper = make_string_function('py_upper', unicode.upper) py_lower = make_string_function('py_lower', unicode.lower) -@print_traceback def py_json_unwrap(value): # [null,some-value] -> some-value assert value.startswith('[null,'), value @@ -462,14 +447,12 @@ def _extract(expr, *paths): result.append(_traverse(expr, keys)) return result[0] if len(paths) == 1 else result -@print_traceback def py_json_extract(expr, *paths): result = _extract(expr, *paths) if type(result) in (list, dict): result = json.dumps(result, **SQLiteJsonConverter.json_kwargs) return result -@print_traceback def py_json_query(expr, path, with_wrapper): result = _extract(expr, path) if type(result) not in (list, dict): @@ -477,26 +460,22 @@ def py_json_query(expr, path, with_wrapper): result = [result] return json.dumps(result, **SQLiteJsonConverter.json_kwargs) -@print_traceback def py_json_value(expr, path): result = _extract(expr, path) return result if type(result) not in (list, dict) else None -@print_traceback def py_json_contains(expr, path, key): expr = json.loads(expr) if isinstance(expr, basestring) else expr keys = _parse_path(path) expr = _traverse(expr, keys) return type(expr) in (list, dict) and key in expr -@print_traceback def py_json_nonzero(expr, path): expr = json.loads(expr) if isinstance(expr, basestring) else expr keys = _parse_path(path) expr = _traverse(expr, keys) return bool(expr) -@print_traceback def py_json_array_length(expr, path=None): expr = json.loads(expr) if isinstance(expr, basestring) else expr if path: From 2869d46cec815cc5916233507a543a7d579f4f9d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 13:01:12 +0300 Subject: [PATCH 154/547] Correctly pass exception from user-defined functions in SQLite --- pony/orm/dbapiprovider.py | 12 +++++-- pony/orm/dbproviders/sqlite.py | 58 +++++++++++++++++++++++++++------- 2 files changed, 57 insertions(+), 13 deletions(-) diff --git a/pony/orm/dbapiprovider.py 
b/pony/orm/dbapiprovider.py index 30f6846f9..1a091e030 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -45,12 +45,20 @@ class NotSupportedError(DatabaseError): pass @decorator def wrap_dbapi_exceptions(func, provider, *args, **kwargs): dbapi_module = provider.dbapi_module - try: return func(provider, *args, **kwargs) + try: + if provider.dialect != 'SQLite': + return func(provider, *args, **kwargs) + else: + provider.local_exceptions.keep_traceback = True + try: return func(provider, *args, **kwargs) + finally: provider.local_exceptions.keep_traceback = False except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) except dbapi_module.ProgrammingError as e: raise ProgrammingError(e) except dbapi_module.InternalError as e: raise InternalError(e) except dbapi_module.IntegrityError as e: raise IntegrityError(e) - except dbapi_module.OperationalError as e: raise OperationalError(e) + except dbapi_module.OperationalError as e: + if provider.dialect == 'SQLite': provider.restore_exception() + raise OperationalError(e) except dbapi_module.DataError as e: raise DataError(e) except dbapi_module.DatabaseError as e: raise DatabaseError(e) except dbapi_module.InterfaceError as e: diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index d6672a984..23eb78d30 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -1,7 +1,7 @@ from __future__ import absolute_import from pony.py23compat import PY2, imap, basestring, buffer, int_types, unicode -import os.path, re, json +import os.path, sys, re, json import sqlite3 as sqlite from decimal import Decimal from datetime import datetime, date, time, timedelta @@ -16,7 +16,7 @@ from pony.orm.core import log_orm from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions -from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, throw +from 
pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise from contextlib import contextmanager @@ -210,8 +210,33 @@ def py2sql(converter, val): class SQLiteJsonConverter(dbapiprovider.JsonConverter): json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} + +class LocalExceptions(localbase): + def __init__(self): + self.exc_info = None + self.keep_traceback = False + +local_exceptions = LocalExceptions() + +def keep_exception(func): + @wraps(func) + def new_func(*args): + local_exceptions.exc_info = None + try: + return func(*args) + except Exception: + local_exceptions.exc_info = sys.exc_info() + if not local_exceptions.keep_traceback: + local_exceptions.exc_info = local_exceptions.exc_info[:2] + (None,) + raise + finally: + local_exceptions.keep_traceback = False + return new_func + + class SQLiteProvider(DBAPIProvider): dialect = 'SQLite' + local_exceptions = local_exceptions max_name_len = 1024 select_for_update_nowait_syntax = False @@ -249,6 +274,11 @@ def inspect_connection(provider, conn): DBAPIProvider.inspect_connection(provider, conn) provider.json1_available = provider.check_json1(conn) + def restore_exception(provider): + if provider.local_exceptions.exc_info is not None: + try: reraise(*provider.local_exceptions.exc_info) + finally: provider.local_exceptions.exc_info = None + @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction @@ -494,15 +524,21 @@ def _connect(pool): throw(IOError, "Database file is not found: %r" % filename) pool.con = con = sqlite.connect(filename, isolation_level=None) con.text_factory = _text_factory - con.create_function('power', 2, pow) - con.create_function('rand', 0, random) - con.create_function('py_upper', 1, py_upper) - con.create_function('py_lower', 1, py_lower) - con.create_function('py_json_unwrap', 1, py_json_unwrap) - con.create_function('py_json_extract', -1, py_json_extract) - 
con.create_function('py_json_contains', 3, py_json_contains) - con.create_function('py_json_nonzero', 2, py_json_nonzero) - con.create_function('py_json_array_length', -1, py_json_array_length) + + def create_function(name, num_params, func): + func = keep_exception(func) + con.create_function(name, num_params, func) + + create_function('power', 2, pow) + create_function('rand', 0, random) + create_function('py_upper', 1, py_upper) + create_function('py_lower', 1, py_lower) + create_function('py_json_unwrap', 1, py_json_unwrap) + create_function('py_json_extract', -1, py_json_extract) + create_function('py_json_contains', 3, py_json_contains) + create_function('py_json_nonzero', 2, py_json_nonzero) + create_function('py_json_array_length', -1, py_json_array_length) + if sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') def disconnect(pool): From 0bd0f40183117424c6900dedc458c4ddc73f1c76 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 13:33:46 +0300 Subject: [PATCH 155/547] Remove unused import --- pony/orm/dbproviders/sqlite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 23eb78d30..f09d0ce9b 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -18,8 +18,6 @@ from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise -from contextlib import contextmanager - class SqliteExtensionUnavailable(Exception): pass From 409449f3f7a7c988897f5947d2bbaa8a4b753777 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 13:57:14 +0300 Subject: [PATCH 156/547] Rename: self -> converter --- pony/orm/dbapiprovider.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 1a091e030..af949cf3e 
100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -754,22 +754,22 @@ def sql_type(converter): class JsonConverter(Converter): json_kwargs = {} class JsonEncoder(json.JSONEncoder): - def default(self, obj): + def default(converter, obj): if isinstance(obj, Json): return obj.wrapped - return json.JSONEncoder.default(self, obj) - def val2dbval(self, val, obj=None): - return json.dumps(val, cls=self.JsonEncoder, **self.json_kwargs) - def dbval2val(self, dbval, obj=None): + return json.JSONEncoder.default(converter, obj) + def val2dbval(converter, val, obj=None): + return json.dumps(val, cls=converter.JsonEncoder, **converter.json_kwargs) + def dbval2val(converter, dbval, obj=None): if isinstance(dbval, (int, bool, float, type(None))): return dbval val = json.loads(dbval) if obj is None: return val - return TrackedValue.make(obj, self.attr, val) - def dbvals_equal(self, x, y): + return TrackedValue.make(obj, converter.attr, val) + def dbvals_equal(converter, x, y): if isinstance(x, basestring): x = json.loads(x) if isinstance(y, basestring): y = json.loads(y) return x == y - def sql_type(self): + def sql_type(converter): return "JSON" From 1ba414058166c38a7b68c59a08730b8c09232006 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 16:50:43 +0300 Subject: [PATCH 157/547] Remove unnecessary val2dbval conversion --- pony/orm/core.py | 29 ++++++++++++++++++++--------- pony/orm/sqlbuilding.py | 10 +++++----- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index d729a70ca..dfb450e24 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4659,7 +4659,7 @@ def _save_principal_objects_(obj, dependent_objects): val = obj._vals_[attr] if val is not None and val._status_ == 'created': val._save_(dependent_objects) - def _update_dbvals_(obj, after_create): + def _update_dbvals_(obj, after_create, new_dbvals): bits = obj._bits_ vals = obj._vals_ dbvals = obj._dbvals_ @@ -4678,10 
+4678,7 @@ def _update_dbvals_(obj, after_create): elif after_create and val is None: obj._rbits_ &= ~bits[attr] else: - # For normal attribute, set `dbval` to the same value as `val` after update/create - # dbvals[attr] = val - converter = attr.converters[0] - dbvals[attr] = converter.val2dbval(val, obj) # TODO this conversion should be unnecessary + dbvals[attr] = new_dbvals.get(attr, val) continue # Clear value of volatile attribute or null values after create, because the value may be changed in the DB del vals[attr] @@ -4691,12 +4688,19 @@ def _save_created_(obj): auto_pk = (obj._pkval_ is None) attrs = [] values = [] + new_dbvals = {} for attr in obj._attrs_with_columns_: if auto_pk and attr.is_pk: continue val = obj._vals_[attr] if val is not None: attrs.append(attr) - values.extend(attr.get_raw_values(val)) + if not attr.reverse: + assert len(attr.converters) == 1 + dbval = attr.converters[0].val2dbval(val, obj) + new_dbvals[attr] = dbval + values.append(dbval) + else: + values.extend(attr.get_raw_values(val)) attrs = tuple(attrs) database = obj._database_ @@ -4746,14 +4750,21 @@ def _save_created_(obj): obj._status_ = 'inserted' obj._rbits_ = obj._all_bits_except_volatile_ obj._wbits_ = 0 - obj._update_dbvals_(True) + obj._update_dbvals_(True, new_dbvals) def _save_updated_(obj): update_columns = [] values = [] + new_dbvals = {} for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._wbits_): update_columns.extend(attr.columns) val = obj._vals_[attr] - values.extend(attr.get_raw_values(val)) + if not attr.reverse: + assert len(attr.converters) == 1 + dbval = attr.converters[0].val2dbval(val, obj) + new_dbvals[attr] = dbval + values.append(dbval) + else: + values.extend(attr.get_raw_values(val)) if update_columns: for attr in obj._pk_attrs_: val = obj._vals_[attr] @@ -4791,7 +4802,7 @@ def _save_updated_(obj): obj._status_ = 'updated' obj._rbits_ |= obj._wbits_ & obj._all_bits_except_volatile_ obj._wbits_ = 0 - obj._update_dbvals_(False) + 
obj._update_dbvals_(False, new_dbvals) def _save_deleted_(obj): values = [] values.extend(obj._get_raw_pkval_()) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 20205329b..ee6672e8a 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -31,11 +31,11 @@ def eval(param, values): if j is not None: assert type(type(value)).__name__ == 'EntityMeta' value = value._get_raw_pkval_()[j] - if value is not None: # can value be None at all? - converter = param.converter - if converter is not None: - if not param.optimistic: value = converter.val2dbval(value) - value = converter.py2sql(value) + converter = param.converter + if value is not None and converter is not None: + if converter.attr is None: + value = converter.val2dbval(value) + value = converter.py2sql(value) return value def __unicode__(param): paramstyle = param.style From b73d6b0ced92a7a0af8bc4b8693ef63df4f18b4c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Aug 2017 18:12:01 +0300 Subject: [PATCH 158/547] Fixes #283: Lost Json update --- pony/orm/core.py | 53 +++++++++++++++++----------------- pony/orm/dbapiprovider.py | 30 +++++++++++-------- pony/orm/dbproviders/oracle.py | 2 +- 3 files changed, 46 insertions(+), 39 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index dfb450e24..9017361c4 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1908,7 +1908,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if converter is not None: try: if from_db: return converter.sql2py(val) - val = converter.validate(val) + val = converter.validate(val, obj) except UnicodeDecodeError as e: throw(ValueError, 'Value for attribute %s cannot be converted to %s: %s' % (attr, unicode.__name__, truncate_repr(val))) @@ -1923,7 +1923,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): except TypeError: throw(TypeError, 'Attribute %s must be of %s type. 
Got: %r' % (attr, rentity.__name__, val)) else: - if obj is not None: cache = obj._session_cache_ + if obj is not None and obj._status_ is not None: cache = obj._session_cache_ else: cache = entity._database_._get_cache() if cache is not val._session_cache_: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') @@ -2236,7 +2236,8 @@ def validate(attr, val, obj=None, entity=None, from_db=False): val = Attribute.validate(attr, val, obj, entity, from_db) if val == '' or (val is None and not (attr.auto or attr.is_volatile or attr.sql_default)): if not from_db: - throw(ValueError, 'Attribute %s is required' % (attr if obj is None else '%r.%s' % (obj, attr.name))) + throw(ValueError, 'Attribute %s is required' % ( + attr if obj is None or obj._status_ is None else '%r.%s' % (obj, attr.name))) else: warnings.warn('Database contains %s for required attribute %s' % ('NULL' if val is None else 'empty string', attr), @@ -2507,7 +2508,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if not isinstance(item, rentity): throw(TypeError, 'Item of collection %s.%s must be an instance of %s. 
Got: %r' % (entity.__name__, attr.name, rentity.__name__, item)) - if obj is not None: cache = obj._session_cache_ + if obj is not None and obj._status_ is not None: cache = obj._session_cache_ else: cache = entity._database_._get_cache() for item in items: if item._session_cache_ is not cache: @@ -3555,26 +3556,6 @@ def _get_pk_columns_(entity): return pk_columns def __iter__(entity): return EntityIter(entity) - def _normalize_args_(entity, kwargs, setdefault=False): - avdict = {} - if setdefault: - for name in kwargs: - if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name) - for attr in entity._attrs_: - val = kwargs.get(attr.name, DEFAULT) - avdict[attr] = attr.validate(val, None, entity, from_db=False) - else: - get_attr = entity._adict_.get - for name, val in iteritems(kwargs): - attr = get_attr(name) - if attr is None: throw(TypeError, 'Unknown attribute %r' % name) - avdict[attr] = attr.validate(val, None, entity, from_db=False) - if entity._pk_is_composite_: - get_val = avdict.get - pkval = tuple(get_val(attr) for attr in entity._pk_attrs_) - if None in pkval: pkval = None - else: pkval = avdict.get(entity._pk_attrs_[0]) - return pkval, avdict @cut_traceback def __getitem__(entity, key): if type(key) is not tuple: key = (key,) @@ -3678,7 +3659,16 @@ def select_random(entity, limit): def _find_one_(entity, kwargs, for_update=False, nowait=False): if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) - pkval, avdict = entity._normalize_args_(kwargs, False) + avdict = {} + get_attr = entity._adict_.get + for name, val in iteritems(kwargs): + attr = get_attr(name) + if attr is None: throw(TypeError, 'Unknown attribute %r' % name) + avdict[attr] = attr.validate(val, None, entity, from_db=False) + if entity._pk_is_composite_: + pkval = tuple(imap(avdict.get, entity._pk_attrs_)) + if None in pkval: pkval = None + else: pkval = avdict.get(entity._pk_attrs_[0]) for attr in 
avdict: if attr.is_collection: throw(TypeError, 'Collection attribute %s cannot be specified as search criteria' % attr) @@ -4190,13 +4180,24 @@ def __reduce__(obj): return unpickle_entity, (d,) @cut_traceback def __init__(obj, *args, **kwargs): + obj._status_ = None entity = obj.__class__ if args: raise TypeError('%s constructor accept only keyword arguments. Got: %d positional argument%s' % (entity.__name__, len(args), len(args) > 1 and 's' or '')) if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) - pkval, avdict = entity._normalize_args_(kwargs, True) + avdict = {} + for name in kwargs: + if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name) + for attr in entity._attrs_: + val = kwargs.get(attr.name, DEFAULT) + avdict[attr] = attr.validate(val, obj, entity, from_db=False) + if entity._pk_is_composite_: + pkval = tuple(imap(avdict.get, entity._pk_attrs_)) + if None in pkval: pkval = None + else: pkval = avdict.get(entity._pk_attrs_[0]) + undo_funcs = [] cache = entity._database_._get_cache() cache_indexes = cache.indexes diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index af949cf3e..fc7b13ba6 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -350,7 +350,7 @@ def __init__(converter, provider, py_type, attr=None): def init(converter, kwargs): attr = converter.attr if attr and attr.args: unexpected_args(attr, attr.args) - def validate(converter, val): + def validate(converter, val, obj=None): return val def py2sql(converter, val): return val @@ -393,7 +393,7 @@ def get_fk_type(converter, sql_type): assert False class BoolConverter(Converter): - def validate(converter, val): + def validate(converter, val, obj=None): return bool(val) def sql2py(converter, val): return bool(val) @@ -421,7 +421,7 @@ def init(converter, kwargs): converter.max_len = max_len converter.db_encoding = kwargs.pop('db_encoding', None) converter.autostrip = 
kwargs.pop('autostrip', True) - def validate(converter, val): + def validate(converter, val, obj=None): if PY2 and isinstance(val, str): val = val.decode('ascii') elif not isinstance(val, unicode): throw(TypeError, 'Value type for attribute %s must be %s. Got: %r' % (converter.attr, unicode.__name__, type(val))) @@ -492,7 +492,7 @@ def init(converter, kwargs): converter.max_val = max_val or highest converter.size = size converter.unsigned = unsigned - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, int_types): pass elif isinstance(val, basestring): try: val = int(val) @@ -539,7 +539,7 @@ def init(converter, kwargs): converter.min_val = min_val converter.max_val = max_val converter.tolerance = kwargs.pop('tolerance', converter.default_tolerance) - def validate(converter, val): + def validate(converter, val, obj=None): try: val = float(val) except ValueError: throw(TypeError, 'Invalid value for attribute %s: %r' % (converter.attr, val)) @@ -604,7 +604,7 @@ def init(converter, kwargs): converter.min_val = min_val converter.max_val = max_val - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, float): s = str(val) if float(s) != val: s = repr(val) @@ -625,7 +625,7 @@ def sql_type(converter): return 'DECIMAL(%d, %d)' % (converter.precision, converter.scale) class BlobConverter(Converter): - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, buffer): return val if isinstance(val, str): return buffer(val) throw(TypeError, "Attribute %r: expected type is 'buffer'. 
Got: %r" % (converter.attr, type(val))) @@ -638,7 +638,7 @@ def sql_type(converter): return 'BLOB' class DateConverter(Converter): - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, datetime): return val.date() if isinstance(val, date): return val if isinstance(val, basestring): return str2date(val) @@ -688,7 +688,7 @@ def sql_type(converter): class TimeConverter(ConverterWithMicroseconds): sql_type_name = 'TIME' - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, time): pass elif isinstance(val, basestring): val = str2time(val) else: throw(TypeError, "Attribute %r: expected type is 'time'. Got: %r" % (converter.attr, val)) @@ -702,7 +702,7 @@ def sql2py(converter, val): class TimedeltaConverter(ConverterWithMicroseconds): sql_type_name = 'INTERVAL' - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, timedelta): pass elif isinstance(val, basestring): val = str2timedelta(val) else: throw(TypeError, "Attribute %r: expected type is 'timedelta'. Got: %r" % (converter.attr, val)) @@ -716,7 +716,7 @@ def sql2py(converter, val): class DatetimeConverter(ConverterWithMicroseconds): sql_type_name = 'DATETIME' - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, datetime): pass elif isinstance(val, basestring): val = str2datetime(val) else: throw(TypeError, "Attribute %r: expected type is 'datetime'. 
Got: %r" % (converter.attr, val)) @@ -734,7 +734,7 @@ def __init__(converter, provider, py_type, attr=None): attr.auto = False if not attr.default: attr.default = uuid4 Converter.__init__(converter, provider, py_type, attr) - def validate(converter, val): + def validate(converter, val, obj=None): if isinstance(val, UUID): return val if isinstance(val, buffer): return UUID(bytes=val) if isinstance(val, basestring): @@ -758,6 +758,12 @@ def default(converter, obj): if isinstance(obj, Json): return obj.wrapped return json.JSONEncoder.default(converter, obj) + def validate(converter, val, obj=None): + if obj is None or converter.attr is None: + return val + if isinstance(val, TrackedValue) and val.obj is obj and val.attr is converter.attr: + return val + return TrackedValue.make(obj, converter.attr, val) def val2dbval(converter, val, obj=None): return json.dumps(val, cls=converter.JsonEncoder, **converter.json_kwargs) def dbval2val(converter, dbval, obj=None): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 600b8880d..16d83214e 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -287,7 +287,7 @@ def sql_type(converter): return "NUMBER(1)" class OraStrConverter(dbapiprovider.StrConverter): - def validate(converter, val): + def validate(converter, val, obj=None): if val == '': return None return dbapiprovider.StrConverter.validate(converter, val) def sql2py(converter, val): From f31b796485c6eb3a4295aaeb22771864f1e4817a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 24 Aug 2017 14:19:43 +0300 Subject: [PATCH 159/547] Fixes #288: AttributeError: module 'symbol' has no attribute 'list_for' with Python 3 --- pony/thirdparty/compiler/transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index 182712439..dea639574 100644 --- a/pony/thirdparty/compiler/transformer.py +++ 
b/pony/thirdparty/compiler/transformer.py @@ -1153,7 +1153,7 @@ def com_list_constructor(self, nodelist): # listmaker: test ( list_for | (',' test)* [','] ) values = [] for i in range(1, len(nodelist)): - if nodelist[i][0] == symbol.list_for: + if PY2 and nodelist[i][0] == symbol.list_for: assert len(nodelist[i:]) == 1 return self.com_list_comprehension(values[0], nodelist[i]) From 74753c4c36df4bbbd5a1f19aa584c6ac349ae4f3 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 24 Aug 2017 16:30:45 +0300 Subject: [PATCH 160/547] Fixes #283 for nested Json objects --- pony/orm/ormtypes.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index ce591746b..a4bc11630 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types +from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types, iteritems import types, weakref from decimal import Decimal @@ -240,8 +240,13 @@ def get_untracked(self): def tracked_method(func): @wraps(func, assigned=('__name__', '__doc__') if PY2 else WRAPPER_ASSIGNMENTS) - def new_func(self, *args, **kw): - result = func(self, *args, **kw) + def new_func(self, *args, **kwargs): + obj = self.obj_ref() + attr = self.attr + if obj is not None: + args = tuple(TrackedValue.make(obj, attr, arg) for arg in args) + if kwargs: kwargs = {key: TrackedValue.make(obj, attr, value) for key, value in iteritems(kwargs)} + result = func(self, *args, **kwargs) self._changed_() return result return new_func @@ -249,13 +254,15 @@ def new_func(self, *args, **kw): class TrackedDict(TrackedValue, dict): def __init__(self, obj, attr, value): TrackedValue.__init__(self, obj, attr) - dict.__init__(self, ((key, self.make(obj, attr, val)) - for key, val in value.items())) + 
dict.__init__(self, {key: self.make(obj, attr, val) for key, val in iteritems(value)}) def __reduce__(self): return dict, (dict(self),) __setitem__ = tracked_method(dict.__setitem__) __delitem__ = tracked_method(dict.__delitem__) - update = tracked_method(dict.update) + _update = tracked_method(dict.update) + def update(self, *args, **kwargs): + args = [ arg if isinstance(arg, dict) else dict(arg) for arg in args ] + return self._update(*args, **kwargs) setdefault = tracked_method(dict.setdefault) pop = tracked_method(dict.pop) popitem = tracked_method(dict.popitem) From 9cb02a7cda7e9e0403584193062edad9b14888c1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 25 Aug 2017 18:27:36 +0300 Subject: [PATCH 161/547] Remove unused flag translator.inside_expr --- pony/orm/sqltranslation.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index f155008e2..9f8dfcd01 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -182,7 +182,6 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef translator.having_conditions = [] translator.order = [] translator.aggregated = False if not optimize else True - translator.inside_expr = False translator.inside_not = False translator.hint_join = False translator.query_result_is_cacheable = True @@ -273,7 +272,6 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef if not m.aggregated: translator.conditions.extend(m.getsql()) else: translator.having_conditions.extend(m.getsql()) - translator.inside_expr = True translator.dispatch(tree.expr) assert not translator.hint_join assert not translator.inside_not From 5feb5194c11f7e27e329de7fb77173093b42f58d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 25 Aug 2017 18:55:25 +0300 Subject: [PATCH 162/547] Remove unused flag translator.inside_not --- pony/orm/sqltranslation.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git 
a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 9f8dfcd01..74cfe60bb 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -182,7 +182,6 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef translator.having_conditions = [] translator.order = [] translator.aggregated = False if not optimize else True - translator.inside_not = False translator.hint_join = False translator.query_result_is_cacheable = True translator.aggregated_subquery_paths = set() @@ -274,7 +273,6 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef translator.dispatch(tree.expr) assert not translator.hint_join - assert not translator.inside_not monad = tree.expr.monad if isinstance(monad, translator.ParamMonad): throw(TranslationError, "External parameter '%s' cannot be used as query result" % ast2src(tree.expr)) @@ -638,12 +636,9 @@ def preCompare(translator, node): ops = node.ops left = node.expr translator.dispatch(left) - inside_not = translator.inside_not # op: '<' | '>' | '=' | '>=' | '<=' | '<>' | '!=' | '==' # | 'in' | 'not in' | 'is' | 'is not' for op, right in node.ops: - translator.inside_not = inside_not - if op == 'not in': translator.inside_not = not inside_not translator.dispatch(right) if op.endswith('in'): monad = right.monad.contains(left.monad, op == 'not in') else: monad = left.monad.cmp(op, right.monad) @@ -655,7 +650,6 @@ def preCompare(translator, node): 'Too complex aggregation, expressions cannot be combined: {EXPR}') monads.append(monad) left = right - translator.inside_not = inside_not if len(monads) == 1: return monads[0] return translator.AndMonad(monads) def postConst(translator, node): @@ -712,11 +706,7 @@ def postBitand(translator, node): def postBitxor(translator, node): left, right = (subnode.monad for subnode in node.nodes) return left ^ right - - def preNot(translator, node): - translator.inside_not = not translator.inside_not def postNot(translator, node): - 
translator.inside_not = not translator.inside_not return node.expr.monad.negate() def preCallFunc(translator, node): if node.star_args is not None: throw(NotImplementedError, '*%s is not supported' % ast2src(node.star_args)) From 70aa83016ba46cc7e934320998bac942356f5b07 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 25 Aug 2017 15:45:35 +0300 Subject: [PATCH 163/547] Minor changes --- pony/orm/dbproviders/postgres.py | 2 +- pony/orm/dbproviders/sqlite.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 90c5b5c20..9316bea1c 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -89,7 +89,7 @@ def eval_json_path(builder, values): def JSON_QUERY(builder, expr, path): path_sql, has_params, has_wildcards = builder.build_json_path(path) return '(', builder(expr), " #> ", path_sql, ')' - json_value_type_mapping = {bool: 'boolean', int: 'integer', float: 'real'} + json_value_type_mapping = {bool: 'boolean', int: 'int', float: 'real'} def JSON_VALUE(builder, expr, path, type): if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) path_sql, has_params, has_wildcards = builder.build_json_path(path) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index f09d0ce9b..b3eacccfe 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -12,8 +12,9 @@ from binascii import hexlify from functools import wraps -from pony.orm import core, dbschema, sqltranslation, dbapiprovider, ormtypes +from pony.orm import core, dbschema, sqltranslation, dbapiprovider from pony.orm.core import log_orm +from pony.orm.ormtypes import Json from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise @@ -260,7 
+261,7 @@ class SQLiteProvider(DBAPIProvider): (timedelta, SQLiteTimedeltaConverter), (UUID, dbapiprovider.UuidConverter), (buffer, dbapiprovider.BlobConverter), - (ormtypes.Json, SQLiteJsonConverter) + (Json, SQLiteJsonConverter) ] def __init__(provider, *args, **kwargs): From f87d55c8569f828a31c88dfc09d83a535242b7fe Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 24 Aug 2017 19:04:15 +0300 Subject: [PATCH 164/547] Add support of explicit casting to int in queries using int() function --- pony/orm/core.py | 2 +- pony/orm/dbproviders/mysql.py | 2 ++ pony/orm/dbproviders/sqlite.py | 7 +++++-- pony/orm/sqltranslation.py | 9 +++++++++ 4 files changed, 17 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 9017361c4..854056a18 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5662,5 +5662,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr} +special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr, int} const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta} diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index a3aa63187..d7830b0b2 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -60,6 +60,8 @@ def LTRIM(builder, expr, chars=None): def RTRIM(builder, expr, chars=None): if chars is None: return 'rtrim(', builder(expr), ')' return 'trim(trailing ', builder(chars), ' from ' ,builder(expr), ')' + def TO_INT(builder, expr): + return 'CAST(', builder(expr), ' AS SIGNED)' def YEAR(builder, expr): return 'year(', builder(expr), ')' def MONTH(builder, expr): diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index b3eacccfe..101345692 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -141,11 +141,14 @@ def JSON_QUERY(builder, expr, path): fname 
= 'json_extract' if builder.json1_available else 'py_json_extract' path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'py_json_unwrap(', fname, '(', builder(expr), ', null, ', path_sql, '))' - # json_value_type_mapping = {unicode: 'text', bool: 'boolean', int: 'integer', float: 'real', Json: None} + json_value_type_mapping = {unicode: 'text', bool: 'integer', int: 'integer', float: 'real'} def JSON_VALUE(builder, expr, path, type): func_name = 'json_extract' if builder.json1_available else 'py_json_extract' path_sql, has_params, has_wildcards = builder.build_json_path(path) - return func_name, '(', builder(expr), ', ', path_sql, ')' + type_name = builder.json_value_type_mapping.get(type) + result = func_name, '(', builder(expr), ', ', path_sql, ')' + if type_name is not None: result = 'CAST(', result, ' as ', type_name, ')' + return result def JSON_NONZERO(builder, expr): return builder(expr), ''' NOT IN ('null', 'false', '0', '""', '[]', '{}')''' def JSON_ARRAY_LENGTH(builder, value): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 74cfe60bb..4d066b365 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1080,6 +1080,8 @@ def __and__(monad): throw(TypeError) def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) def cast_from_json(monad, type): assert False, monad + def to_int(monad): + return NumericExprMonad(monad.translator, int, [ 'TO_INT', monad.getsql()[0] ]) class RawSQLMonad(Monad): def __init__(monad, translator, rawtype, varkey): @@ -1729,6 +1731,8 @@ def get_path(monad): monad = monad.parent path.reverse() return monad, path + def to_int(monad): + return monad.cast_from_json(int) def cast_from_json(monad, type): translator = monad.translator if issubclass(type, Json): @@ -1971,6 +1975,11 @@ def call(monad, source, encoding=None, errors=None): else: value = buffer(source) return translator.ConstMonad.new(translator, value) +class FuncIntMonad(FuncMonad): + func = 
int + def call(monad, x): + return x.to_int() + class FuncDecimalMonad(FuncMonad): func = Decimal def call(monad, x): From cf5285a65d1189d82d6741810b72331281f0cab9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 24 Aug 2017 18:31:36 +0300 Subject: [PATCH 165/547] Fixes #284: query.order_by() orders Json numbers like strings --- pony/orm/sqltranslation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 4d066b365..594c3417e 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -181,6 +181,7 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef translator.conditions = subquery.conditions translator.having_conditions = [] translator.order = [] + translator.inside_order_by = False translator.aggregated = False if not optimize else True translator.hint_join = False translator.query_result_is_cacheable = True @@ -605,6 +606,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, extractor if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes else: nodes = (func_ast,) if order_by: + translator.inside_order_by = True new_order = [] for node in nodes: if isinstance(node.monad, translator.SetMixin): @@ -614,6 +616,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, extractor % (t, ast2src(node))) new_order.extend(node.monad.getsql()) translator.order[:0] = new_order + translator.inside_order_by = False else: for node in nodes: monad = node.monad @@ -1745,6 +1748,9 @@ def cast_from_json(monad, type): def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] + translator = monad.translator + if translator.inside_order_by and translator.dialect == 'SQLite': + return [ [ 'JSON_VALUE', base_sql, path, None ] ] return [ [ 'JSON_QUERY', base_sql, path ] ] class ConstMonad(Monad): From 3a649fecaba6f1364a66ee3f1516dcd0768f2f92 Mon Sep 17 00:00:00 2001 From: Alexander 
Tischenko Date: Fri, 25 Aug 2017 20:33:10 +0300 Subject: [PATCH 166/547] Fixes #280: repeated database locked (Operational Error) issues in high activity, db intensive application. --- pony/orm/dbproviders/sqlite.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 101345692..3921e1e2f 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -353,7 +353,7 @@ def release(provider, connection, cache=None): raise DBAPIProvider.release(provider, connection, cache) - def get_pool(provider, filename, create_db=False): + def get_pool(provider, filename, create_db=False, **kwargs): if filename != ':memory:': # When relative filename is specified, it is considered # not relative to cwd, but to user module where @@ -369,7 +369,7 @@ def get_pool(provider, filename, create_db=False): # 1 - SQLiteProvider.__init__() # 0 - pony.dbproviders.sqlite.get_pool() filename = absolutize_path(filename, frame_depth=7) - return SQLitePool(filename, create_db) + return SQLitePool(filename, create_db, **kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): return provider._exists(connection, table_name, None, case_sensitive) @@ -516,15 +516,16 @@ def py_json_array_length(expr, path=None): return len(expr) if type(expr) is list else 0 class SQLitePool(Pool): - def __init__(pool, filename, create_db): # called separately in each thread + def __init__(pool, filename, create_db, **kwargs): # called separately in each thread pool.filename = filename pool.create_db = create_db + pool.kwargs = kwargs pool.con = None def _connect(pool): filename = pool.filename if filename != ':memory:' and not pool.create_db and not os.path.exists(filename): throw(IOError, "Database file is not found: %r" % filename) - pool.con = con = sqlite.connect(filename, isolation_level=None) + pool.con = con = sqlite.connect(filename, isolation_level=None, **pool.kwargs) 
con.text_factory = _text_factory def create_function(name, num_params, func): From b82474b85ba93bf46ce705c60e59436a5b1de100 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 28 Aug 2017 16:01:40 +0300 Subject: [PATCH 167/547] Fixes #266: Add handler to "pony.orm" logger does not work --- pony/orm/core.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 854056a18..9a900868a 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -79,8 +79,19 @@ def sql_debug(value): orm_log_level = logging.INFO +def has_handlers(logger): + if not PY2: + return logger.hasHandlers() + while logger: + if logger.handlers: + return True + elif not logger.propagate: + return False + logger = logger.parent + return False + def log_orm(msg): - if logging.root.handlers: + if has_handlers(orm_logger): orm_logger.log(orm_log_level, msg) else: print(msg) @@ -88,7 +99,7 @@ def log_orm(msg): def log_sql(sql, arguments=None): if type(arguments) is list: sql = 'EXECUTEMANY (%d)\n%s' % (len(arguments), sql) - if logging.root.handlers: + if has_handlers(sql_logger): sql_logger.log(orm_log_level, sql) # arguments can hold sensitive information else: print(sql) From 554f2f2738e983ab3668e9f3323f639a028436de Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 15:31:58 +0300 Subject: [PATCH 168/547] Make debug flag thread-local --- pony/orm/core.py | 14 +++++++------- pony/orm/dbapiprovider.py | 14 +++++++------- pony/orm/dbproviders/mysql.py | 6 +++--- pony/orm/dbproviders/oracle.py | 4 ++-- pony/orm/dbproviders/postgres.py | 6 +++--- pony/orm/dbproviders/sqlite.py | 8 ++++---- pony/orm/dbschema.py | 4 ++-- 7 files changed, 28 insertions(+), 28 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 9a900868a..b4cae137d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -67,12 +67,11 @@ 'user_groups_getter', 'user_roles_getter', 'obj_labels_getter' ] -debug = False 
suppress_debug_change = False def sql_debug(value): - global debug - if not suppress_debug_change: debug = value + if not suppress_debug_change: + local.debug = value orm_logger = logging.getLogger('pony.orm') sql_logger = logging.getLogger('pony.orm.sql') @@ -263,6 +262,7 @@ def adapt_sql(sql, paramstyle): class Local(localbase): def __init__(local): + local.debug = False local.db2cache = {} local.db_context_counter = 0 local.db_session = None @@ -709,14 +709,14 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti if start_transaction: cache.immediate = True connection = cache.prepare_connection_for_query_execution() cursor = connection.cursor() - if debug: log_sql(sql, arguments) + if local.debug: log_sql(sql, arguments) provider = database.provider t = time() try: new_id = provider.execute(cursor, sql, arguments, returning_id) except Exception as e: connection = cache.reconnect(e) cursor = connection.cursor() - if debug: log_sql(sql, arguments) + if local.debug: log_sql(sql, arguments) t = time() new_id = provider.execute(cursor, sql, arguments, returning_id) if cache.immediate: cache.in_transaction = True @@ -943,7 +943,7 @@ def _drop_tables(database, table_names, if_exists, with_all_data, try_normalized 'Cannot drop table %s because it is not empty. 
Specify option ' 'with_all_data=True if you want to drop table with all data' % table_name) for table_name in existed_tables: - if debug: log_orm('DROPPING TABLE %s' % table_name) + if local.debug: log_orm('DROPPING TABLE %s' % table_name) provider.drop_table(connection, table_name) @cut_traceback @db_session(ddl=True) @@ -1524,7 +1524,7 @@ def reconnect(cache, exc): if exc is not None: exc = getattr(exc, 'original_exc', exc) if not provider.should_reconnect(exc): reraise(*sys.exc_info()) - if debug: log_orm('CONNECTION FAILED: %s' % exc) + if local.debug: log_orm('CONNECTION FAILED: %s' % exc) connection = cache.connection assert connection is not None cache.connection = None diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index fc7b13ba6..2bb090765 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -205,14 +205,14 @@ def set_transaction_mode(provider, connection, cache): @wrap_dbapi_exceptions def commit(provider, connection, cache=None): core = pony.orm.core - if core.debug: core.log_orm('COMMIT') + if core.local.debug: core.log_orm('COMMIT') connection.commit() if cache is not None: cache.in_transaction = False @wrap_dbapi_exceptions def rollback(provider, connection, cache=None): core = pony.orm.core - if core.debug: core.log_orm('ROLLBACK') + if core.local.debug: core.log_orm('ROLLBACK') connection.rollback() if cache is not None: cache.in_transaction = False @@ -222,20 +222,20 @@ def release(provider, connection, cache=None): if cache is not None and cache.db_session is not None and cache.db_session.ddl: provider.drop(connection, cache) else: - if core.debug: core.log_orm('RELEASE CONNECTION') + if core.local.debug: core.log_orm('RELEASE CONNECTION') provider.pool.release(connection) @wrap_dbapi_exceptions def drop(provider, connection, cache=None): core = pony.orm.core - if core.debug: core.log_orm('CLOSE CONNECTION') + if core.local.debug: core.log_orm('CLOSE CONNECTION') provider.pool.drop(connection) if cache is 
not None: cache.in_transaction = False @wrap_dbapi_exceptions def disconnect(provider): core = pony.orm.core - if core.debug: core.log_orm('DISCONNECT') + if core.local.debug: core.log_orm('DISCONNECT') provider.pool.disconnect() @wrap_dbapi_exceptions @@ -311,10 +311,10 @@ def connect(pool): pool.con = pool.pid = None core = pony.orm.core if pool.con is None: - if core.debug: core.log_orm('GET NEW CONNECTION') + if core.local.debug: core.log_orm('GET NEW CONNECTION') pool._connect() pool.pid = pid - elif core.debug: core.log_orm('GET CONNECTION FROM THE LOCAL POOL') + elif core.local.debug: core.log_orm('GET CONNECTION FROM THE LOCAL POOL') return pool.con def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index d7830b0b2..28ab3aa5f 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -261,7 +261,7 @@ def set_transaction_mode(provider, connection, cache): if fk is not None: fk = (fk[1] == 'ON') if fk: sql = 'SET foreign_key_checks = 0' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.saved_fk_state = bool(fk) cache.in_transaction = True @@ -269,7 +269,7 @@ def set_transaction_mode(provider, connection, cache): if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.in_transaction = True @@ -281,7 +281,7 @@ def release(provider, connection, cache=None): try: cursor = connection.cursor() sql = 'SET foreign_key_checks = 1' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) except: provider.pool.drop(connection) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 16d83214e..9fc7d269f 100644 --- a/pony/orm/dbproviders/oracle.py +++ 
b/pony/orm/dbproviders/oracle.py @@ -440,7 +440,7 @@ def set_transaction_mode(provider, connection, cache): if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.immediate = True if db_session is not None and (db_session.serializable or db_session.ddl): @@ -570,7 +570,7 @@ def connect(pool): pool.forked_pools.append((pool.cx_pool, pool.pid)) pool.cx_pool = cx_Oracle.SessionPool(**pool.kwargs) pool.pid = os.getpid() - if core.debug: log_orm('GET CONNECTION') + if core.local.debug: log_orm('GET CONNECTION') con = pool.cx_pool.acquire() con.outputtypehandler = output_type_handler return con diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 9316bea1c..41412d8d5 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -192,16 +192,16 @@ def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction if cache.immediate and connection.autocommit: connection.autocommit = False - if core.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') + if core.local.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') db_session = cache.db_session if db_session is not None and db_session.serializable: cursor = connection.cursor() sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) elif not cache.immediate and not connection.autocommit: connection.autocommit = True - if core.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') + if core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') if db_session is not None and (db_session.serializable or db_session.ddl): cache.in_transaction = True diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 3921e1e2f..d3e5ba343 100644 --- a/pony/orm/dbproviders/sqlite.py 
+++ b/pony/orm/dbproviders/sqlite.py @@ -296,17 +296,17 @@ def set_transaction_mode(provider, connection, cache): if fk is not None: fk = fk[0] if fk: sql = 'PRAGMA foreign_keys = false' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.saved_fk_state = bool(fk) assert cache.immediate if cache.immediate: sql = 'BEGIN IMMEDIATE TRANSACTION' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) cache.in_transaction = True - elif core.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') + elif core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') finally: if cache.immediate and not cache.in_transaction: provider.transaction_lock.release() @@ -346,7 +346,7 @@ def release(provider, connection, cache=None): try: cursor = connection.cursor() sql = 'PRAGMA foreign_keys = true' - if core.debug: log_orm(sql) + if core.local.debug: log_orm(sql) cursor.execute(sql) except: provider.pool.drop(connection) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index e6c3f8b7e..fcc0cd8af 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -73,13 +73,13 @@ def check_tables(schema, provider, connection): [ 'WHERE', [ 'EQ', [ 'VALUE', 0 ], [ 'VALUE', 1 ] ] ] ] sql, adapter = provider.ast2sql(sql_ast) - if core.debug: log_sql(sql) + if core.local.debug: log_sql(sql) provider.execute(cursor, sql) class DBObject(object): def create(table, provider, connection): sql = table.get_create_command() - if core.debug: log_sql(sql) + if core.local.debug: log_sql(sql) cursor = connection.cursor() provider.execute(cursor, sql) From cdc41d52e6e278ce2cba8b13b736ad78e0654c6b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 16:40:58 +0300 Subject: [PATCH 169/547] sql_debugging context manager added. 
--- pony/orm/core.py | 71 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 70 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b4cae137d..5f55c6ce0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -46,7 +46,7 @@ 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', 'TranslationError', 'ExprEvalError', 'PermissionError', - 'Database', 'sql_debug', 'show', + 'Database', 'sql_debug', 'sql_debugging', 'show', 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator', 'composite_key', 'composite_index', @@ -263,6 +263,7 @@ def adapt_sql(sql, paramstyle): class Local(localbase): def __init__(local): local.debug = False + local.debug_stack = [] local.db2cache = {} local.db_context_counter = 0 local.db_session = None @@ -270,6 +271,11 @@ def __init__(local): local.perms_context = None local.user_groups_cache = {} local.user_roles_cache = defaultdict(dict) + def push_debug_state(local, debug): + local.debug_stack.append(local.debug) + local.debug = debug + def pop_debug_state(local): + local.debug = local.debug_stack.pop() local = Local() @@ -506,6 +512,69 @@ def wrapped_interact(iterator, input=None, exc_info=None): db_session = DBSessionContextManager() + +class SQLDebuggingContextManager(object): + def __init__(self, debug=True): + self.debug = debug + def __call__(self, *args, **kwargs): + if not kwargs and len(args) == 1 and callable(args[0]): + arg = args[0] + if not isgeneratorfunction(arg): + return self._wrap_function(arg) + return self._wrap_generator_function(arg) + return self.__class__(*args, **kwargs) + def __enter__(self): + local.push_debug_state(self.debug) + def __exit__(self, exc_type=None, exc=None, tb=None): + local.pop_debug_state() + def _wrap_function(self, func): + def new_func(func, *args, **kwargs): + self.__enter__() + try: + return func(*args, **kwargs) + finally: + self.__exit__() + return decorator(new_func, func) + def _wrap_generator_function(self, gen_func): + 
def interact(iterator, input=None, exc_info=None): + if exc_info is None: + return next(iterator) if input is None else iterator.send(input) + + if exc_info[0] is GeneratorExit: + close = getattr(iterator, 'close', None) + if close is not None: close() + reraise(*exc_info) + + throw_ = getattr(iterator, 'throw', None) + if throw_ is None: reraise(*exc_info) + return throw_(*exc_info) + + def new_gen_func(gen_func, *args, **kwargs): + def wrapped_interact(iterator, input=None, exc_info=None): + self.__enter__() + try: + return interact(iterator, input, exc_info) + finally: + self.__exit__() + + gen = gen_func(*args, **kwargs) + iterator = iter(gen) + output = wrapped_interact(iterator) + try: + while True: + try: + input = yield output + except: + output = wrapped_interact(iterator, exc_info=sys.exc_info()) + else: + output = wrapped_interact(iterator, input) + except StopIteration: + return + return decorator(new_gen_func, gen_func) + +sql_debugging = SQLDebuggingContextManager() + + def throw_db_session_is_over(action, obj, attr=None): msg = 'Cannot %s %s%s: the database session is over' throw(DatabaseSessionIsOver, msg % (action, safe_repr(obj), '.%s' % attr.name if attr else '')) From f54f08140042667b79d24cc03818db1b55e52afc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 18:08:42 +0300 Subject: [PATCH 170/547] show_values flag added to sql_debugging --- pony/orm/core.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 5f55c6ce0..ef7d949f8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -99,14 +99,17 @@ def log_sql(sql, arguments=None): if type(arguments) is list: sql = 'EXECUTEMANY (%d)\n%s' % (len(arguments), sql) if has_handlers(sql_logger): - sql_logger.log(orm_log_level, sql) # arguments can hold sensitive information + if local.show_values and arguments: + sql = '%s\n%s' % (sql, format_arguments(arguments)) + 
sql_logger.log(orm_log_level, sql) else: - print(sql) - if not arguments: pass - elif type(arguments) is list: - for args in arguments: print(args2str(args)) - else: print(args2str(arguments)) - print() + if (local.show_values is None or local.show_values) and arguments: + sql = '%s\n%s' % (sql, format_arguments(arguments)) + print(sql, end='\n\n') + +def format_arguments(arguments): + if type(arguments) is not list: return args2str(arguments) + return '\n'.join(args2str(args) for args in arguments) def args2str(args): if isinstance(args, (tuple, list)): @@ -263,6 +266,7 @@ def adapt_sql(sql, paramstyle): class Local(localbase): def __init__(local): local.debug = False + local.show_values = None local.debug_stack = [] local.db2cache = {} local.db_context_counter = 0 @@ -271,11 +275,12 @@ def __init__(local): local.perms_context = None local.user_groups_cache = {} local.user_roles_cache = defaultdict(dict) - def push_debug_state(local, debug): - local.debug_stack.append(local.debug) + def push_debug_state(local, debug, show_values): + local.debug_stack.append((local.debug, local.show_values)) local.debug = debug + local.show_values = show_values def pop_debug_state(local): - local.debug = local.debug_stack.pop() + local.debug, local.show_values = local.debug_stack.pop() local = Local() @@ -514,8 +519,9 @@ def wrapped_interact(iterator, input=None, exc_info=None): class SQLDebuggingContextManager(object): - def __init__(self, debug=True): + def __init__(self, debug=True, show_values=None): self.debug = debug + self.show_values = show_values def __call__(self, *args, **kwargs): if not kwargs and len(args) == 1 and callable(args[0]): arg = args[0] @@ -524,7 +530,7 @@ def __call__(self, *args, **kwargs): return self._wrap_generator_function(arg) return self.__class__(*args, **kwargs) def __enter__(self): - local.push_debug_state(self.debug) + local.push_debug_state(self.debug, self.show_values) def __exit__(self, exc_type=None, exc=None, tb=None): 
local.pop_debug_state() def _wrap_function(self, func): From ac3f5727f4c3112d01b87e6ee116faa0e9d04690 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 18:38:37 +0300 Subject: [PATCH 171/547] Add sql_debug and show_values arguments to db_session --- pony/orm/core.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index ef7d949f8..26cf27440 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -356,9 +356,10 @@ def rollback(): select_re = re.compile(r'\s*select\b', re.IGNORECASE) class DBSessionContextManager(object): - __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', 'immediate', 'ddl', 'serializable', 'strict' + __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', 'immediate', 'ddl', 'serializable', 'strict', \ + 'sql_debug', 'show_values' def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, - retry_exceptions=(TransactionError,), allowed_exceptions=()): + retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None): if retry is not 0: if type(retry) is not int: throw(TypeError, "'retry' parameter of db_session must be of integer type. 
Got: %s" % type(retry)) @@ -377,6 +378,8 @@ def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False db_session.strict = strict db_session.retry_exceptions = retry_exceptions db_session.allowed_exceptions = allowed_exceptions + db_session.sql_debug = sql_debug + db_session.show_values = show_values def __call__(db_session, *args, **kwargs): if not args and not kwargs: return db_session if len(args) > 1: throw(TypeError, @@ -401,7 +404,11 @@ def _enter(db_session): elif db_session.serializable and not local.db_session.serializable: throw(TransactionError, 'Cannot start serializable transaction inside non-serializable transaction') local.db_context_counter += 1 + if db_session.sql_debug is not None: + local.push_debug_state(db_session.sql_debug, db_session.show_values) def __exit__(db_session, exc_type=None, exc=None, tb=None): + if db_session.sql_debug is not None: + local.pop_debug_state() local.db_context_counter -= 1 if local.db_context_counter: return assert local.db_session is db_session @@ -431,6 +438,8 @@ def new_func(func, *args, **kwargs): if db_session.ddl and local.db_context_counter: if isinstance(func, types.FunctionType): func = func.__name__ + '()' throw(TransactionError, '%s cannot be called inside of db_session' % func) + if db_session.sql_debug is not None: + local.push_debug_state(db_session.sql_debug, db_session.show_values) exc = tb = None try: for i in xrange(db_session.retry+1): @@ -448,7 +457,10 @@ def new_func(func, *args, **kwargs): if not do_retry: raise finally: db_session.__exit__(exc_type, exc, tb) reraise(exc_type, exc, tb) - finally: del exc, tb + finally: + del exc, tb + if db_session.sql_debug is not None: + local.pop_debug_state() return decorator(new_func, func) def _wrap_generator_function(db_session, gen_func): for option in ('ddl', 'retry', 'serializable'): @@ -479,6 +491,8 @@ def wrapped_interact(iterator, input=None, exc_info=None): local.db_session = db_session local.db2cache.update(db2cache_copy) 
db2cache_copy.clear() + if db_session.sql_debug is not None: + local.push_debug_state(db_session.sql_debug, db_session.show_values) try: try: output = interact(iterator, input, exc_info) @@ -495,6 +509,8 @@ def wrapped_interact(iterator, input=None, exc_info=None): else: return output finally: + if db_session.sql_debug is not None: + local.pop_debug_state() db2cache_copy.update(local.db2cache) local.db2cache.clear() local.db_context_counter = 0 From 5e6ec2cba542da448b2995a2bbcbe0cfd7efce34 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 18:50:42 +0300 Subject: [PATCH 172/547] Add set_sql_debug function --- pony/orm/core.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 26cf27440..d81d7f793 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -46,7 +46,7 @@ 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', 'TranslationError', 'ExprEvalError', 'PermissionError', - 'Database', 'sql_debug', 'sql_debugging', 'show', + 'Database', 'sql_debug', 'set_sql_debug', 'sql_debugging', 'show', 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator', 'composite_key', 'composite_index', @@ -70,9 +70,17 @@ suppress_debug_change = False def sql_debug(value): + # todo: make sql_debug deprecated if not suppress_debug_change: local.debug = value + +def set_sql_debug(debug=True, show_values=None): + if not suppress_debug_change: + local.debug = debug + local.show_values = show_values + + orm_logger = logging.getLogger('pony.orm') sql_logger = logging.getLogger('pony.orm.sql') @@ -277,8 +285,9 @@ def __init__(local): local.user_roles_cache = defaultdict(dict) def push_debug_state(local, debug, show_values): local.debug_stack.append((local.debug, local.show_values)) - local.debug = debug - local.show_values = show_values + if not suppress_debug_change: + local.debug = debug + local.show_values = show_values def pop_debug_state(local): local.debug, 
local.show_values = local.debug_stack.pop() From 892d5364aa8761f9510512d1956ff45035d13af8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 19:30:47 +0300 Subject: [PATCH 173/547] Cosmetic changes --- pony/orm/tests/test_bug_182.py | 6 +----- pony/orm/tests/test_indexes.py | 4 ++-- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/pony/orm/tests/test_bug_182.py b/pony/orm/tests/test_bug_182.py index f96bdbbe4..5d59b2fe7 100644 --- a/pony/orm/tests/test_bug_182.py +++ b/pony/orm/tests/test_bug_182.py @@ -4,8 +4,6 @@ from pony.orm import * from pony import orm -import os - class Test(unittest.TestCase): @@ -36,9 +34,7 @@ class Server(db.Entity): @db_session def test(self): - qu = left_join( - (s.name, s.user.name) for s in self.db.Server - )[:] + qu = left_join((s.name, s.user.name) for s in self.db.Server)[:] for server, user in qu: if user is None: break diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py index 0aee54622..403d673cc 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -14,7 +14,7 @@ class Person(db.Entity): composite_key(name, 'age') db.generate_mapping(create_tables=True) - [ i1, i2 ] = Person._indexes_ + i1, i2 = Person._indexes_ self.assertEqual(i1.attrs, (Person.id,)) self.assertEqual(i1.is_pk, True) self.assertEqual(i1.is_unique, True) @@ -38,7 +38,7 @@ class Person(db.Entity): composite_index(name, 'age') db.generate_mapping(create_tables=True) - [ i1, i2 ] = Person._indexes_ + i1, i2 = Person._indexes_ self.assertEqual(i1.attrs, (Person.id,)) self.assertEqual(i1.is_pk, True) self.assertEqual(i1.is_unique, True) From a68007452cf47ca1e3c8a4c003547f38464bfe25 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 19:32:42 +0300 Subject: [PATCH 174/547] Fix wrong test name --- pony/orm/tests/test_indexes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py 
index 403d673cc..eef32618f 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -58,7 +58,7 @@ class Person(db.Entity): index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")' self.assertTrue(index_sql in create_script) - def test_2(self): + def test_3(self): db = Database('sqlite', ':memory:') class User(db.Entity): name = Required(str, unique=True) From cae000fff2fe4f9ebbd49ea214903447924219f6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 19:53:23 +0300 Subject: [PATCH 175/547] Fixes #170: Problem with a primary key column used as a part of another key --- pony/orm/dbschema.py | 6 +++--- pony/orm/tests/test_bug_170.py | 22 ++++++++++++++++++++++ 2 files changed, 25 insertions(+), 3 deletions(-) create mode 100644 pony/orm/tests/test_bug_170.py diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index fcc0cd8af..19597bf57 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -257,9 +257,9 @@ def __init__(index, name, table, columns, is_pk=False, is_unique=None): throw(DBSchemaError, 'Index %s cannot be created, name is already in use' % name) Constraint.__init__(index, name, schema) for column in columns: - column.is_pk = len(columns) == 1 and is_pk - column.is_pk_part = bool(is_pk) - column.is_unique = is_unique and len(columns) == 1 + column.is_pk = column.is_pk or (len(columns) == 1 and is_pk) + column.is_pk_part = column.is_pk_part or bool(is_pk) + column.is_unique = column.is_unique or (is_unique and len(columns) == 1) table.indexes[columns] = index index.table = table index.columns = columns diff --git a/pony/orm/tests/test_bug_170.py b/pony/orm/tests/test_bug_170.py new file mode 100644 index 000000000..e60127d78 --- /dev/null +++ b/pony/orm/tests/test_bug_170.py @@ -0,0 +1,22 @@ +import unittest + +from pony import orm + +class Test(unittest.TestCase): + def test_1(self): + db = orm.Database('sqlite', ':memory:') + + class Person(db.Entity): + id = 
orm.PrimaryKey(int, auto=True) + name = orm.Required(str) + orm.composite_key(id, name) + + db.generate_mapping(create_tables=True) + + table = db.schema.tables[Person._table_] + pk_column = table.column_dict[Person.id.column] + self.assertTrue(pk_column.is_pk) + + with orm.db_session: + p1 = Person(name='John') + p2 = Person(name='Mike') From f6246e98afaaebfedbb2014202d38af9038994fc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 30 Aug 2017 16:33:26 +0300 Subject: [PATCH 176/547] Fix bugs with composite table names --- pony/orm/dbschema.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 19597bf57..51131f2ee 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -31,7 +31,8 @@ def add_table(schema, table_name): def order_tables_to_create(schema): tables = [] created_tables = set() - tables_to_create = sorted(itervalues(schema.tables), key=lambda table: table.name) + split = schema.provider.split_table_name + tables_to_create = sorted(itervalues(schema.tables), key=lambda table: split(table.name)) while tables_to_create: for table in tables_to_create: if table.parent_tables.issubset(created_tables): @@ -63,7 +64,8 @@ def create_tables(schema, provider, connection): 'Try to delete %s %s first.' 
% (tn1, n1, tn2, n2, n2, tn2)) def check_tables(schema, provider, connection): cursor = connection.cursor() - for table in sorted(itervalues(schema.tables), key=lambda table: table.name): + split = provider.split_table_name + for table in sorted(itervalues(schema.tables), key=lambda table: split(table.name)): if isinstance(table.name, tuple): alias = table.name[-1] elif isinstance(table.name, basestring): alias = table.name else: assert False # pragma: no cover From 0f02aad5cc4ea43677f5e5d844e46a3c95fb5f56 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 29 Aug 2017 20:31:07 +0300 Subject: [PATCH 177/547] Fix invalid foreign key & index names for tables which names include schema name --- pony/orm/dbapiprovider.py | 11 +++++++++-- pony/orm/dbschema.py | 4 +--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 2bb090765..ea8a13eb1 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -158,12 +158,12 @@ def get_default_index_name(provider, table_name, column_names, is_pk=False, is_u if is_unique: template = 'unq_%(tname)s__%(cnames)s' elif m2m: template = 'idx_%(tname)s' else: template = 'idx_%(tname)s__%(cnames)s' - index_name = template % dict(tname=table_name, + index_name = template % dict(tname=provider.base_name(table_name), cnames='_'.join(name for name in column_names)) return provider.normalize_name(index_name.lower()) def get_default_fk_name(provider, child_table_name, parent_table_name, child_column_names): - fk_name = 'fk_%s__%s' % (child_table_name, '__'.join(child_column_names)) + fk_name = 'fk_%s__%s' % (provider.base_name(child_table_name), '__'.join(child_column_names)) return provider.normalize_name(fk_name.lower()) def split_table_name(provider, table_name): @@ -177,6 +177,13 @@ def split_table_name(provider, table_name): size, 's' if size != 1 else '', table_name)) return table_name[0], table_name[1] + def base_name(provider, name): + if not 
isinstance(name, basestring): + assert type(name) is tuple + name = name[-1] + assert isinstance(name, basestring) + return name + def quote_name(provider, name): quote_char = provider.quote_char if isinstance(name, basestring): diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 51131f2ee..f968d555f 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -66,9 +66,7 @@ def check_tables(schema, provider, connection): cursor = connection.cursor() split = provider.split_table_name for table in sorted(itervalues(schema.tables), key=lambda table: split(table.name)): - if isinstance(table.name, tuple): alias = table.name[-1] - elif isinstance(table.name, basestring): alias = table.name - else: assert False # pragma: no cover + alias = provider.base_name(table.name) sql_ast = [ 'SELECT', [ 'ALL', ] + [ [ 'COLUMN', alias, column.name ] for column in table.column_list ], [ 'FROM', [ alias, 'TABLE', table.name ] ], From de7e7e1317e1cc934b3b6f8c36b18ad7386a2e8d Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 30 Aug 2017 18:06:35 +0300 Subject: [PATCH 178/547] Raise on unknown options for attributes that are part of relationship --- pony/orm/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index d81d7f793..2e80100fc 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1985,6 +1985,8 @@ def linked(attr): if reverse.is_collection: throw(TypeError, "'cascade_delete' option cannot be set for attribute %s, " "because reverse attribute %s is collection" % (attr, reverse)) + for option in attr.kwargs: + throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option)) @cut_traceback def __repr__(attr): owner_name = attr.entity.__name__ if attr.entity else '?' 
@@ -2532,7 +2534,6 @@ def __init__(attr, py_type, *args, **kwargs): else: attr.reverse_columns = [] attr.nplus1_threshold = kwargs.pop('nplus1_threshold', 1) - for option in attr.kwargs: throw(TypeError, 'Unknown option %r' % option) attr.cached_load_sql = {} attr.cached_add_m2m_sql = None attr.cached_remove_m2m_sql = None From dae99b66141d5b14c314efc60d066abac62df323 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 30 Aug 2017 17:26:45 +0300 Subject: [PATCH 179/547] fk_name option added for attributes --- pony/orm/core.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2e80100fc..7c3a675b8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -978,16 +978,16 @@ def get_columns(table, column_names): m2m_table = schema.tables[attr.table] parent_columns = get_columns(table, entity._pk_columns_) child_columns = get_columns(m2m_table, reverse.columns) - m2m_table.add_foreign_key(None, child_columns, table, parent_columns, attr.index) + m2m_table.add_foreign_key(reverse.fk_name, child_columns, table, parent_columns, attr.index) if attr.symmetric: child_columns = get_columns(m2m_table, attr.reverse_columns) - m2m_table.add_foreign_key(None, child_columns, table, parent_columns) + m2m_table.add_foreign_key(attr.reverse_fk_name, child_columns, table, parent_columns) elif attr.reverse and attr.columns: rentity = attr.reverse.entity parent_table = schema.tables[rentity._table_] parent_columns = get_columns(parent_table, rentity._pk_columns_) child_columns = get_columns(table, attr.columns) - table.add_foreign_key(None, child_columns, parent_table, parent_columns, attr.index) + table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index) elif attr.index and attr.columns: columns = tuple(imap(table.column_dict.__getitem__, attr.columns)) table.add_index(attr.index, columns, is_unique=attr.is_unique) @@ -1853,7 +1853,7 @@ class Attribute(object): 
'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \ 'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \ 'cascade_delete', 'index', 'original_default', 'sql_default', 'py_check', 'hidden', \ - 'optimistic' + 'optimistic', 'fk_name' def __deepcopy__(attr, memo): return attr # Attribute cannot be cloned by deepcopy() @cut_traceback @@ -1912,6 +1912,7 @@ def __init__(attr, py_type, *args, **kwargs): if len(attr.columns) == 1: attr.column = attr.columns[0] else: attr.columns = [] attr.index = kwargs.pop('index', None) + attr.fk_name = kwargs.pop('fk_name', None) attr.col_paths = [] attr._columns_checked = False attr.composite_keys = [] @@ -1985,6 +1986,9 @@ def linked(attr): if reverse.is_collection: throw(TypeError, "'cascade_delete' option cannot be set for attribute %s, " "because reverse attribute %s is collection" % (attr, reverse)) + if attr.is_collection and not reverse.is_collection: + if attr.fk_name is not None: + throw(TypeError, 'You should specify fk_name in %s instead of %s' % (reverse, attr)) for option in attr.kwargs: throw(TypeError, 'Attribute %s has unknown option %r' % (attr, option)) @cut_traceback @@ -2500,7 +2504,7 @@ def __new__(cls, *args, **kwargs): class Collection(Attribute): __slots__ = 'table', 'wrapper_class', 'symmetric', 'reverse_column', 'reverse_columns', \ 'nplus1_threshold', 'cached_load_sql', 'cached_add_m2m_sql', 'cached_remove_m2m_sql', \ - 'cached_count_sql', 'cached_empty_sql' + 'cached_count_sql', 'cached_empty_sql', 'reverse_fk_name' def __init__(attr, py_type, *args, **kwargs): if attr.__class__ is Collection: throw(TypeError, "'Collection' is abstract type") table = kwargs.pop('table', None) # TODO: rename table to link_table or m2m_table @@ -2533,6 +2537,8 @@ def __init__(attr, py_type, *args, **kwargs): if len(attr.reverse_columns) == 1: attr.reverse_column = attr.reverse_columns[0] else: attr.reverse_columns = [] + attr.reverse_fk_name = 
kwargs.pop('reverse_fk_name', None) + attr.nplus1_threshold = kwargs.pop('nplus1_threshold', 1) attr.cached_load_sql = {} attr.cached_add_m2m_sql = None From ba11b480bbf60e84fa2c5dc045054e2017472be2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 7 Sep 2017 19:41:45 +0300 Subject: [PATCH 180/547] Fix TestDatabase provider: add support of `provider` keyword constructor argument and `json1_available` flag for SQLite --- pony/orm/tests/testutils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/orm/tests/testutils.py b/pony/orm/tests/testutils.py index 890cddf8d..e87c65952 100644 --- a/pony/orm/tests/testutils.py +++ b/pony/orm/tests/testutils.py @@ -96,7 +96,9 @@ class TestDatabase(Database): real_provider_name = None raw_server_version = None sql = None - def bind(self, provider_name, *args, **kwargs): + def bind(self, provider, *args, **kwargs): + provider_name = provider + assert isinstance(provider_name, basestring) if self.real_provider_name is not None: provider_name = self.real_provider_name self.provider_name = provider_name @@ -118,6 +120,7 @@ def bind(self, provider_name, *args, **kwargs): server_version = int('%d%02d%02d' % server_version) class TestProvider(provider_cls): + json1_available = False # for SQLite def inspect_connection(provider, connection): pass TestProvider.server_version = server_version From 03260bff431ceaf142cd2da796176954641fb82a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 7 Sep 2017 19:42:27 +0300 Subject: [PATCH 181/547] Fix aliases in queries.txt --- pony/orm/tests/queries.txt | 164 ++++++++++++++++++------------------- 1 file changed, 82 insertions(+), 82 deletions(-) diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index fb51bbd58..f3cb605a6 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -106,9 +106,9 @@ WHERE "s"."group" = 101 >>> avg(s.gpa for s in Student if s.group.dept.number == 44) SELECT AVG("s"."gpa") -FROM 
"Student" "s", "Group" "group-1" -WHERE "group-1"."dept" = 44 - AND "s"."group" = "group-1"."number" +FROM "Student" "s", "Group" "group" +WHERE "group"."dept" = 44 + AND "s"."group" = "group"."number" >>> select(s for s in Student if s.group.number == 101 and s.dob == max(s.dob for s in Student if s.group.number == 101)) @@ -116,9 +116,9 @@ SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" FROM "Student" "s" WHERE "s"."group" = 101 AND "s"."dob" = ( - SELECT MAX("s"."dob") - FROM "Student" "s" - WHERE "s"."group" = 101 + SELECT MAX("s-2"."dob") + FROM "Student" "s-2" + WHERE "s-2"."group" = 101 ) >>> select(g for g in Group if avg(s.gpa for s in g.students) > 4.5) @@ -135,10 +135,10 @@ WHERE ( SELECT "g"."number" FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" GROUP BY "g"."number" -HAVING AVG("student-1"."gpa") > 4.5 +HAVING AVG("student"."gpa") > 4.5 >>> select((s.group, min(s.gpa), max(s.gpa)) for s in Student) @@ -154,19 +154,19 @@ WHERE "s"."group" = 101 >>> select((g, count(g.students)) for g in Group if g.dept.number == 44) -SELECT "g"."number", COUNT(DISTINCT "student-1"."id") +SELECT "g"."number", COUNT(DISTINCT "student"."id") FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" WHERE "g"."dept" = 44 GROUP BY "g"."number" >>> select((s.group, count(s)) for s in Student if s.group.dept.number == 44) SELECT "s"."group", COUNT(DISTINCT "s"."id") -FROM "Student" "s", "Group" "group-1" -WHERE "group-1"."dept" = 44 - AND "s"."group" = "group-1"."number" +FROM "Student" "s", "Group" "group" +WHERE "group"."dept" = 44 + AND "s"."group" = "group"."number" GROUP BY "s"."group" >>> select((g, count(s for s in g.students if s.gpa <= 3), count(s for s in g.students if s.gpa > 3 and s.gpa <= 4), count(s for s in g.students if 
s.gpa > 4)) for g in Group) @@ -222,35 +222,35 @@ GROUP BY "item"."order" >>> select((order, sum(order.items.price * order.items.quantity)) for order in Order if order.id == 123) -SELECT "order"."id", coalesce(SUM(("orderitem-1"."price" * "orderitem-1"."quantity")), 0) +SELECT "order"."id", coalesce(SUM(("orderitem"."price" * "orderitem"."quantity")), 0) FROM "Order" "order" - LEFT JOIN "OrderItem" "orderitem-1" - ON "order"."id" = "orderitem-1"."order" + LEFT JOIN "OrderItem" "orderitem" + ON "order"."id" = "orderitem"."order" WHERE "order"."id" = 123 GROUP BY "order"."id" >>> select((item.order, item.order.total_price, sum(item.price * item.quantity)) for item in OrderItem if item.order.total_price < sum(item.price * item.quantity)) -SELECT "item"."order", "order-1"."total_price", coalesce(SUM(("item"."price" * "item"."quantity")), 0) -FROM "OrderItem" "item", "Order" "order-1" -WHERE "item"."order" = "order-1"."id" -GROUP BY "item"."order", "order-1"."total_price" -HAVING "order-1"."total_price" < coalesce(SUM(("item"."price" * "item"."quantity")), 0) +SELECT "item"."order", "order"."total_price", coalesce(SUM(("item"."price" * "item"."quantity")), 0) +FROM "OrderItem" "item", "Order" "order" +WHERE "item"."order" = "order"."id" +GROUP BY "item"."order", "order"."total_price" +HAVING "order"."total_price" < coalesce(SUM(("item"."price" * "item"."quantity")), 0) >>> select(c for c in Customer for p in c.orders.items.product if 'Tablets' in p.categories.name and count(p) > 1) SELECT DISTINCT "c"."id" -FROM "Customer" "c", "Order" "order-1", "OrderItem" "orderitem-1" +FROM "Customer" "c", "Order" "order", "OrderItem" "orderitem" WHERE 'Tablets' IN ( - SELECT "category-1"."name" - FROM "Category_Product" "t-1", "Category" "category-1" - WHERE "orderitem-1"."product" = "t-1"."product" - AND "t-1"."category" = "category-1"."id" + SELECT "category"."name" + FROM "Category_Product" "t-1", "Category" "category" + WHERE "orderitem"."product" = "t-1"."product" + AND 
"t-1"."category" = "category"."id" ) - AND "c"."id" = "order-1"."customer" - AND "order-1"."id" = "orderitem-1"."order" + AND "c"."id" = "order"."customer" + AND "order"."id" = "orderitem"."order" GROUP BY "c"."id" -HAVING COUNT(DISTINCT "orderitem-1"."product") > 1 +HAVING COUNT(DISTINCT "orderitem"."product") > 1 Schema: pony.orm.examples.university1 @@ -258,9 +258,9 @@ pony.orm.examples.university1 >>> select((s.group, count(s)) for s in Student if s.group.dept.number == 44 and avg(s.gpa) > 4) SELECT "s"."group", COUNT(DISTINCT "s"."id") -FROM "Student" "s", "Group" "group-1" -WHERE "group-1"."dept" = 44 - AND "s"."group" = "group-1"."number" +FROM "Student" "s", "Group" "group" +WHERE "group"."dept" = 44 + AND "s"."group" = "group"."number" GROUP BY "s"."group" HAVING AVG("s"."gpa") > 4 @@ -268,19 +268,19 @@ HAVING AVG("s"."gpa") > 4 SELECT "g"."number" FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" GROUP BY "g"."number" -HAVING MAX("student-1"."gpa") < 4 +HAVING MAX("student"."gpa") < 4 >>> select(g for g in Group if JOIN(max(g.students.gpa) < 4)) SELECT "g"."number" FROM "Group" "g" LEFT JOIN ( - SELECT "student-1"."group" AS "group", MAX("student-1"."gpa") AS "expr-1" - FROM "Student" "student-1" - GROUP BY "student-1"."group" + SELECT "student"."group" AS "group", MAX("student"."gpa") AS "expr-1" + FROM "Student" "student" + GROUP BY "student"."group" ) "t-1" ON "g"."number" = "t-1"."group" WHERE "t-1"."expr-1" < 4 @@ -559,9 +559,9 @@ WHERE "s"."TEL" IS NULL SELECT DISTINCT "s"."name" FROM "Student" "s" WHERE "s"."name" IN ( - SELECT "s"."name" - FROM "Student" "s" - GROUP BY "s"."name" + SELECT "s-2"."name" + FROM "Student" "s-2" + GROUP BY "s-2"."name" HAVING COUNT(*) > 1 ) @@ -675,11 +675,11 @@ SELECT COUNT(*) FROM ( SELECT "g"."number" FROM "Group" "g" - LEFT JOIN "Student" "student-1" - ON "g"."number" = "student-1"."group" + LEFT JOIN 
"Student" "student" + ON "g"."number" = "student"."group" WHERE "g"."number" > 101 GROUP BY "g"."number" - HAVING COUNT(DISTINCT "student-1"."id") > 0 + HAVING COUNT(DISTINCT "student"."id") > 0 ) "t" >>> count(g for g in Group if count(s for s in g.students) > 0 and g.number > 101) @@ -731,26 +731,26 @@ WHERE "gpa" > 3 DELETE FROM "Student" WHERE "id" IN ( SELECT "s"."id" - FROM "Student" "s", "Group" "group-1" - WHERE "group-1"."dept" = 1 - AND "s"."group" = "group-1"."number" + FROM "Student" "s", "Group" "group" + WHERE "group"."dept" = 1 + AND "s"."group" = "group"."number" ) MySQL: DELETE s FROM `student` `s` - INNER JOIN `group` `group-1` - ON `s`.`group` = `group-1`.`number` -WHERE `group-1`.`dept` = 1 + INNER JOIN `group` `group` + ON `s`.`group` = `group`.`number` +WHERE `group`.`dept` = 1 PostgreSQL: DELETE FROM "student" WHERE "id" IN ( SELECT "s"."id" - FROM "student" "s", "group" "group-1" - WHERE "group-1"."dept" = 1 - AND "s"."group" = "group-1"."number" + FROM "student" "s", "group" "group" + WHERE "group"."dept" = 1 + AND "s"."group" = "group"."number" ) Oracle: @@ -758,9 +758,9 @@ Oracle: DELETE FROM "STUDENT" WHERE "ID" IN ( SELECT "s"."ID" - FROM "STUDENT" "s", "GROUP" "group-1" - WHERE "group-1"."DEPT" = 1 - AND "s"."GROUP" = "group-1"."NUMBER" + FROM "STUDENT" "s", "GROUP" "group" + WHERE "group"."DEPT" = 1 + AND "s"."GROUP" = "group"."NUMBER" ) >>> select(c for c in Course if c.dept.name.startswith('D')).delete(bulk=True) @@ -768,26 +768,26 @@ WHERE "ID" IN ( DELETE FROM "Course" WHERE "ROWID" IN ( SELECT "c"."ROWID" - FROM "Course" "c", "Department" "department-1" - WHERE "department-1"."name" LIKE 'D%' - AND "c"."dept" = "department-1"."number" + FROM "Course" "c", "Department" "department" + WHERE "department"."name" LIKE 'D%' + AND "c"."dept" = "department"."number" ) MySQL: DELETE c FROM `course` `c` - INNER JOIN `department` `department-1` - ON `c`.`dept` = `department-1`.`number` -WHERE `department-1`.`name` LIKE 'D%%' + INNER JOIN 
`department` `department` + ON `c`.`dept` = `department`.`number` +WHERE `department`.`name` LIKE 'D%%' PostgreSQL: DELETE FROM "course" WHERE ("name", "semester") IN ( SELECT "c"."name", "c"."semester" - FROM "course" "c", "department" "department-1" - WHERE "department-1"."name" LIKE 'D%%' - AND "c"."dept" = "department-1"."number" + FROM "course" "c", "department" "department" + WHERE "department"."name" LIKE 'D%%' + AND "c"."dept" = "department"."number" ) Oracle: @@ -795,9 +795,9 @@ Oracle: DELETE FROM "COURSE" WHERE "ROWID" IN ( SELECT "c"."ROWID" - FROM "COURSE" "c", "DEPARTMENT" "department-1" - WHERE "department-1"."NAME" LIKE 'D%' - AND "c"."DEPT" = "department-1"."NUMBER" + FROM "COURSE" "c", "DEPARTMENT" "department" + WHERE "department"."NAME" LIKE 'D%' + AND "c"."DEPT" = "department"."NUMBER" ) >>> select(s for s in Student if s.gpa > 3 and s not in (s2 for s2 in Student if s2.group.dept.name.startswith('A'))).delete(bulk=True) @@ -806,10 +806,10 @@ DELETE FROM "Student" WHERE "gpa" > 3 AND "id" NOT IN ( SELECT "s2"."id" - FROM "Student" "s2", "Group" "group-1", "Department" "department-1" - WHERE "department-1"."name" LIKE 'A%' - AND "s2"."group" = "group-1"."number" - AND "group-1"."dept" = "department-1"."number" + FROM "Student" "s2", "Group" "group", "Department" "department" + WHERE "department"."name" LIKE 'A%' + AND "s2"."group" = "group"."number" + AND "group"."dept" = "department"."number" ) # MySQL does not support such queries @@ -820,10 +820,10 @@ DELETE FROM "student" WHERE "gpa" > 3 AND "id" NOT IN ( SELECT "s2"."id" - FROM "student" "s2", "group" "group-1", "department" "department-1" - WHERE "department-1"."name" LIKE 'A%%' - AND "s2"."group" = "group-1"."number" - AND "group-1"."dept" = "department-1"."number" + FROM "student" "s2", "group" "group", "department" "department" + WHERE "department"."name" LIKE 'A%%' + AND "s2"."group" = "group"."number" + AND "group"."dept" = "department"."number" ) Oracle: @@ -832,10 +832,10 @@ DELETE 
FROM "STUDENT" WHERE "GPA" > 3 AND "ID" NOT IN ( SELECT "s2"."ID" - FROM "STUDENT" "s2", "GROUP" "group-1", "DEPARTMENT" "department-1" - WHERE "department-1"."NAME" LIKE 'A%' - AND "s2"."GROUP" = "group-1"."NUMBER" - AND "group-1"."DEPT" = "department-1"."NUMBER" + FROM "STUDENT" "s2", "GROUP" "group", "DEPARTMENT" "department" + WHERE "department"."NAME" LIKE 'A%' + AND "s2"."GROUP" = "group"."NUMBER" + AND "group"."DEPT" = "department"."NUMBER" ) >>> select(s for s in Student if exists(s2 for s2 in Student if s.gpa > s2.gpa)).delete(bulk=True) From 2d0938e06ffa86f0ce765b9f25274fbb6e7d587a Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 7 Sep 2017 19:44:41 +0300 Subject: [PATCH 182/547] Added modulo division native support in queries. --- pony/orm/dbproviders/oracle.py | 2 ++ pony/orm/sqlbuilding.py | 3 +++ pony/orm/sqltranslation.py | 3 +++ pony/orm/tests/queries.txt | 26 ++++++++++++++++++++++++++ 4 files changed, 34 insertions(+) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 9fc7d269f..7ef0a148e 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -214,6 +214,8 @@ def DATE(builder, expr): return 'TRUNC(', builder(expr), ')' def RANDOM(builder): return 'dbms_random.value' + def MOD(builder, a, b): + return 'MOD(', builder(a), ', ', builder(b), ')' def DATE_ADD(builder, expr, delta): if isinstance(delta, timedelta): return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index ee6672e8a..d449eed7e 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -400,6 +400,9 @@ def POW(builder, expr1, expr2): DIV = make_binary_op(' / ', True) FLOORDIV = make_binary_op(' / ', True) + def MOD(builder, a, b): + symbol = ' %% ' if builder.paramstyle in ('format', 'pyformat') else ' % ' + return '(', builder(a), symbol, builder(b), ')' def FLOAT_EQ(builder, a, b): a, b = builder(a), 
builder(b) return 'abs(', a, ' - ', b, ') / coalesce(nullif(greatest(abs(', a, '), abs(', b, ')), 0), 1) <= 1e-14' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 594c3417e..a8f2d291f 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -690,6 +690,8 @@ def postDiv(translator, node): return node.left.monad / node.right.monad def postFloorDiv(translator, node): return node.left.monad // node.right.monad + def postMod(translator, node): + return node.left.monad % node.right.monad def postPower(translator, node): return node.left.monad ** node.right.monad def postUnarySub(translator, node): @@ -1247,6 +1249,7 @@ def mixin_init(monad): __mul__ = make_numeric_binop('*', 'MUL') __truediv__ = make_numeric_binop('/', 'DIV') __floordiv__ = make_numeric_binop('//', 'FLOORDIV') + __mod__ = make_numeric_binop('%', 'MOD') def __pow__(monad, monad2): translator = monad.translator if not isinstance(monad2, translator.NumericMixin): diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index f3cb605a6..bca7e8c00 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -888,3 +888,29 @@ PostgreSQL: SELECT DISTINCT upper("s"."name") FROM "student" "s" + +# Test modulo division operator + +>>> select(s for s in Student if s.id % 2 == 0) + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE ("s"."id" % 2) = 0 + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."id" %% 2) = 0 + +MySQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."id" %% 2) = 0 + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE MOD("s"."ID", 2) = 0 From 6c8a985013a3eea8abe9b4d2a6a3f73a8032332c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 16:00:08 +0300 Subject: 
[PATCH 183/547] Fix quotes in queries.txt for MySQL --- pony/orm/tests/queries.txt | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index bca7e8c00..2fb773fc4 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -905,9 +905,9 @@ WHERE ("s"."id" %% 2) = 0 MySQL: -SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" -FROM "student" "s" -WHERE ("s"."id" %% 2) = 0 +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE (`s`.`id` %% 2) = 0 Oracle: From 3e5c0142431f461cacaf669d07479149bee7456d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 16:01:01 +0300 Subject: [PATCH 184/547] Fix incorrect aliases in nested queries --- pony/orm/sqltranslation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a8f2d291f..be75ab281 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -875,8 +875,8 @@ def join_table(subquery, parent_alias, alias, table_name, join_cond): class TableRef(object): def __init__(tableref, subquery, name, entity): tableref.subquery = subquery - tableref.name_path = name tableref.alias = subquery.make_alias(name) + tableref.name_path = tableref.alias tableref.entity = entity tableref.joined = False tableref.can_affect_distinct = True From c697b2de7bb0cb6e4387eeac15835d56b10081de Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 17:23:10 +0300 Subject: [PATCH 185/547] Translator bug fixed --- pony/orm/asttranslation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index ca5e57661..5f722249e 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -328,7 +328,8 @@ def create_extractors(code_key, tree, filter_num, globals, locals, 
getattr_attrname_values = {} for node in pretranslator.getattr_nodes: if node in pretranslator.externals: - code = extractors[filter_num, node.src] + src = node.src + code = extractors[filter_num, src] getattr_extractors[src] = code attrname_value = eval(code, globals, locals) getattr_attrname_values[src] = attrname_value From 377bf2e183fff49c01347082a0704e430755d8e1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 17:24:12 +0300 Subject: [PATCH 186/547] Local variable renaming --- pony/orm/asttranslation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 5f722249e..947357ffd 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -325,14 +325,14 @@ def create_extractors(code_key, tree, filter_num, globals, locals, extractors[filter_num, src] = code getattr_extractors = {} - getattr_attrname_values = {} + getattr_attrname_dict = {} for node in pretranslator.getattr_nodes: if node in pretranslator.externals: src = node.src code = extractors[filter_num, src] getattr_extractors[src] = code attrname_value = eval(code, globals, locals) - getattr_attrname_values[src] = attrname_value + getattr_attrname_dict[src] = attrname_value elif isinstance(node, ast.Const): attrname_value = node.value else: throw(TypeError, '`%s` should be either external expression or constant.' 
% ast2src(node)) @@ -342,7 +342,7 @@ def create_extractors(code_key, tree, filter_num, globals, locals, getattr_cache[getattr_key] = tuple(sorted(getattr_extractors.items())) varnames = list(sorted(extractors)) - getattr_attrname_values = tuple(val for key, val in sorted(getattr_attrname_values.items())) + getattr_attrname_values = tuple(val for key, val in sorted(getattr_attrname_dict.items())) extractors_key = (code_key, filter_num, getattr_attrname_values) result = extractors_cache[extractors_key] = extractors, varnames, tree, extractors_key return result From 0cafa7db8fda15ce94f279f2fa3fb06031ec94b2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 18:07:06 +0300 Subject: [PATCH 187/547] Move pickle_ast and unpickle_ast to from pony.orm.core to pony.utils --- pony/orm/core.py | 29 +++-------------------------- pony/orm/sqltranslation.py | 8 ++++---- pony/utils/utils.py | 26 ++++++++++++++++++++++++-- 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7c3a675b8..b4d32b781 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1,8 +1,8 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \ - basestring, unicode, buffer, int_types, builtins, pickle, with_metaclass + basestring, unicode, buffer, int_types, builtins, with_metaclass -import io, json, re, sys, types, datetime, logging, itertools, warnings +import json, re, sys, types, datetime, logging, itertools, warnings from operator import attrgetter, itemgetter from itertools import chain, starmap, repeat from time import time @@ -27,7 +27,7 @@ ) from pony import utils from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \ - deprecated, import_module, parse_expr, is_ident, tostring, strjoin, concat + pickle_ast, unpickle_ast, deprecated, import_module, 
parse_expr, is_ident, tostring, strjoin, concat __all__ = [ 'pony', @@ -5185,29 +5185,6 @@ def extract_vars(extractors, globals, locals, cells=None): def unpickle_query(query_result): return query_result -def persistent_id(obj): - if obj is Ellipsis: - return "Ellipsis" - -def persistent_load(persid): - if persid == "Ellipsis": - return Ellipsis - raise pickle.UnpicklingError("unsupported persistent object") - -def pickle_ast(val): - pickled = io.BytesIO() - pickler = pickle.Pickler(pickled) - pickler.persistent_id = persistent_id - pickler.dump(val) - return pickled - -def unpickle_ast(pickled): - pickled.seek(0) - unpickler = pickle.Unpickler(pickled) - unpickler.persistent_load = persistent_load - return unpickler.load() - - class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index be75ab281..c90582629 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, pickle, with_metaclass +from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, with_metaclass import types, sys, re, itertools from decimal import Decimal @@ -12,7 +12,7 @@ from pony.thirdparty.compiler import ast from pony import options, utils -from pony.utils import is_ident, throw, reraise, concat +from pony.utils import is_ident, throw, reraise, concat, pickle_ast, unpickle_ast from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ @@ -596,8 +596,8 @@ def apply_kwfilters(translator, filterattrs): return translator def apply_lambda(translator, filter_num, order_by, func_ast, argnames, extractors, 
vartypes): translator = deepcopy(translator) - pickled_func_ast = pickle.dumps(func_ast, 2) - func_ast = pickle.loads(pickled_func_ast) # func_ast = deepcopy(func_ast) + pickled_func_ast = pickle_ast(func_ast) + func_ast = unpickle_ast(pickled_func_ast) # func_ast = deepcopy(func_ast) translator.filter_num = filter_num translator.extractors.update(extractors) translator.vartypes.update(vartypes) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 7182c6854..f4adfcee5 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -1,9 +1,9 @@ #coding: cp1251 from __future__ import absolute_import, print_function -from pony.py23compat import PY2, imap, basestring, unicode +from pony.py23compat import PY2, imap, basestring, unicode, pickle -import re, os, os.path, sys, datetime, inspect, types, linecache, warnings, json +import io, re, os, os.path, sys, datetime, inspect, types, linecache, warnings, json from itertools import count as _count from inspect import isfunction, ismethod @@ -511,3 +511,25 @@ def concat(*args): def is_utf8(encoding): return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') + +def _persistent_id(obj): + if obj is Ellipsis: + return "Ellipsis" + +def _persistent_load(persid): + if persid == "Ellipsis": + return Ellipsis + raise pickle.UnpicklingError("unsupported persistent object") + +def pickle_ast(val): + pickled = io.BytesIO() + pickler = pickle.Pickler(pickled) + pickler.persistent_id = _persistent_id + pickler.dump(val) + return pickled + +def unpickle_ast(pickled): + pickled.seek(0) + unpickler = pickle.Unpickler(pickled) + unpickler.persistent_load = _persistent_load + return unpickler.load() From 0665843ca3314aa3a01e786486eb41f3013e318f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 18:11:28 +0300 Subject: [PATCH 188/547] Add copy_ast function --- pony/orm/sqltranslation.py | 5 ++--- pony/utils/utils.py | 3 +++ 2 files changed, 5 insertions(+), 3 deletions(-) diff --git 
a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index c90582629..39b941639 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -12,7 +12,7 @@ from pony.thirdparty.compiler import ast from pony import options, utils -from pony.utils import is_ident, throw, reraise, concat, pickle_ast, unpickle_ast +from pony.utils import is_ident, throw, reraise, concat, copy_ast from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ @@ -596,8 +596,7 @@ def apply_kwfilters(translator, filterattrs): return translator def apply_lambda(translator, filter_num, order_by, func_ast, argnames, extractors, vartypes): translator = deepcopy(translator) - pickled_func_ast = pickle_ast(func_ast) - func_ast = unpickle_ast(pickled_func_ast) # func_ast = deepcopy(func_ast) + func_ast = copy_ast(func_ast) # func_ast = deepcopy(func_ast) translator.filter_num = filter_num translator.extractors.update(extractors) translator.vartypes.update(vartypes) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index f4adfcee5..9bf4530aa 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -533,3 +533,6 @@ def unpickle_ast(pickled): unpickler = pickle.Unpickler(pickled) unpickler.persistent_load = _persistent_load return unpickler.load() + +def copy_ast(tree): + return unpickle_ast(pickle_ast(tree)) From d1ca677517af3edec7cad963c0b493b8cb42ef40 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 18:17:21 +0300 Subject: [PATCH 189/547] Fixes 223: incorrect result of getattr(entity, attrname) when the same lambda applies to different entities --- pony/orm/asttranslation.py | 5 ++++- pony/orm/tests/test_getattr.py | 25 ++++++++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 947357ffd..376f4675f 100644 --- 
a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -5,7 +5,7 @@ from pony.thirdparty.compiler import ast -from pony.utils import throw +from pony.utils import throw, copy_ast class TranslationError(Exception): pass @@ -313,6 +313,9 @@ def create_extractors(code_key, tree, filter_num, globals, locals, result = extractors_cache.get(extractors_key) except TypeError: pass # unhashable type + if not result: + tree = copy_ast(tree) + if not result: pretranslator = PreTranslator( tree, globals, locals, special_functions, const_functions, additional_internal_names) diff --git a/pony/orm/tests/test_getattr.py b/pony/orm/tests/test_getattr.py index 8704f9f5d..629d50623 100644 --- a/pony/orm/tests/test_getattr.py +++ b/pony/orm/tests/test_getattr.py @@ -33,17 +33,18 @@ class Artist(db.Entity): db.generate_mapping(check_tables=True, create_tables=True) with orm.db_session: - pop = Genre(name='pop') + pop = Genre(name='Pop') Artist(name='Sia', age=40, genres=[pop]) + Hobby(name='Swimming') pony.options.INNER_JOIN_SYNTAX = True @db_session def test_no_caching(self): - for attr, type in zip(['name', 'age'], [basestring, int]): - val = select(getattr(x, attr) for x in self.db.Artist).first() - self.assertIsInstance(val, type) - + for attr_name, attr_type in zip(['name', 'age'], [basestring, int]): + val = select(getattr(x, attr_name) for x in self.db.Artist).first() + self.assertIsInstance(val, attr_type) + @db_session def test_simple(self): val = select(getattr(x, 'age') for x in self.db.Artist).first() @@ -88,4 +89,18 @@ def test_not_string(self): name = 1 select(getattr(x, name) for x in self.db.Artist) + @db_session + def test_lambda_1(self): + for name, value in [('name', 'Sia'), ('age', 40), ('name', 'Sia')]: + result = self.db.Artist.select(lambda a: getattr(a, name) == value) + self.assertEqual(set(obj.name for obj in result), {'Sia'}) + @db_session + def test_lambda_2(self): + for entity, name, value in [ + (self.db.Genre, 'name', 'Pop'), + (self.db.Artist, 
'age', 40), + (self.db.Hobby, 'name', 'Swimming'), + ]: + result = entity.select(lambda a: getattr(a, name) == value) + self.assertEqual(set(result[:]), {entity.select().first()}) From 519a486aaf22202dfff29dccfccadb25245fccbb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 6 Sep 2017 17:39:38 +0300 Subject: [PATCH 190/547] Refactoring: remove filter_num from create_extractors() --- pony/orm/asttranslation.py | 16 +++++++--------- pony/orm/core.py | 21 +++++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 376f4675f..f2366578c 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -301,14 +301,12 @@ def postCallFunc(translator, node): getattr_cache = {} extractors_cache = {} -def create_extractors(code_key, tree, filter_num, globals, locals, - special_functions, const_functions, additional_internal_names=()): +def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, additional_internal_names=()): result = None - getattr_key = code_key, filter_num - getattr_extractors = getattr_cache.get(getattr_key) + getattr_extractors = getattr_cache.get(code_key) if getattr_extractors: getattr_attrname_values = tuple(eval(code, globals, locals) for src, code in getattr_extractors) - extractors_key = (code_key, filter_num, getattr_attrname_values) + extractors_key = (code_key, getattr_attrname_values) try: result = extractors_cache.get(extractors_key) except TypeError: @@ -325,14 +323,14 @@ def create_extractors(code_key, tree, filter_num, globals, locals, src = node.src = ast2src(node) if src == '.0': code = None else: code = compile(src, src, 'eval') - extractors[filter_num, src] = code + extractors[src] = code getattr_extractors = {} getattr_attrname_dict = {} for node in pretranslator.getattr_nodes: if node in pretranslator.externals: src = node.src - code = extractors[filter_num, src] + code = extractors[src] 
getattr_extractors[src] = code attrname_value = eval(code, globals, locals) getattr_attrname_dict[src] = attrname_value @@ -342,10 +340,10 @@ def create_extractors(code_key, tree, filter_num, globals, locals, if not isinstance(attrname_value, basestring): throw(TypeError, '%s: attribute name must be string. Got: %r' % (ast2src(node.parent_node), attrname_value)) node._attrname_value = attrname_value - getattr_cache[getattr_key] = tuple(sorted(getattr_extractors.items())) + getattr_cache[code_key] = tuple(sorted(getattr_extractors.items())) varnames = list(sorted(extractors)) getattr_attrname_values = tuple(val for key, val in sorted(getattr_attrname_dict.items())) - extractors_key = (code_key, filter_num, getattr_attrname_values) + extractors_key = (code_key, getattr_attrname_values) result = extractors_cache[extractors_key] = extractors, varnames, tree, extractors_key return result diff --git a/pony/orm/core.py b/pony/orm/core.py index b4d32b781..9cc23f6f6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5151,15 +5151,15 @@ def raw_sql(sql, result_type=None): locals = sys._getframe(1).f_locals return RawSQL(sql, globals, locals, result_type) -def extract_vars(extractors, globals, locals, cells=None): +def extract_vars(filter_num, extractors, globals, locals, cells=None): if cells: locals = locals.copy() for name, cell in cells.items(): locals[name] = cell.cell_contents vars = {} vartypes = {} - for key, code in iteritems(extractors): - filter_num, src = key + for src, code in iteritems(extractors): + key = filter_num, src if src == '.0': value = locals['.0'] else: try: value = eval(code, globals, locals) @@ -5189,11 +5189,12 @@ class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) extractors, varnames, tree, pretranslator_key = create_extractors( - code_key, tree, 0, globals, locals, special_functions, const_functions) - vars, vartypes = extract_vars(extractors, 
globals, locals, cells) + code_key, tree, globals, locals, special_functions, const_functions) + filter_num = 0 + vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) node = tree.quals[0].iter - origin = vars[0, node.src] + origin = vars[filter_num, node.src] if isinstance(origin, EntityIter): origin = origin.entity elif not isinstance(origin, EntityMeta): if node.src == '.0': throw(TypeError, 'Cannot iterate over non-entity object') @@ -5204,7 +5205,7 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) database.provider.normalize_vars(vars, vartypes) query._vars = vars - query._key = pretranslator_key, tuple(vartypes[name] for name in varnames), left_join + query._key = pretranslator_key, tuple(vartypes[filter_num, name] for name in varnames), left_join query._database = database translator = database._translator_cache.get(query._key) @@ -5505,14 +5506,14 @@ def _process_lambda(query, func, globals, locals, order_by): filter_num = len(query._filters) + 1 extractors, varnames, func_ast, pretranslator_key = create_extractors( - func_id, func_ast, filter_num, globals, locals, special_functions, const_functions, + func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) if extractors: - vars, vartypes = extract_vars(extractors, globals, locals, cells) + vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) new_query_vars = query._vars.copy() new_query_vars.update(vars) - sorted_vartypes = tuple(vartypes[name] for name in varnames) + sorted_vartypes = tuple(vartypes[filter_num, name] for name in varnames) else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () new_key = query._key + (('order_by' if order_by else 'filter', pretranslator_key, sorted_vartypes),) 
From 0047188c8ce263dc30e64dfcd856437e83f84b0e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 18:59:29 +0300 Subject: [PATCH 191/547] Skip optimistic checks for queries in db_session with serializable=True --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 9cc23f6f6..907b18016 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4889,7 +4889,7 @@ def _save_updated_(obj): val = obj._vals_[attr] values.extend(attr.get_raw_values(val)) cache = obj._session_cache_ - if obj not in cache.for_update: + if not cache.db_session.serializable and obj not in cache.for_update: optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) @@ -4926,7 +4926,7 @@ def _save_deleted_(obj): values = [] values.extend(obj._get_raw_pkval_()) cache = obj._session_cache_ - if obj not in cache.for_update: + if not cache.db_session.serializable and obj not in cache.for_update: optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) From e60c4325484a3d86573b6037eaa222ab5ad711e5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Sep 2017 19:13:59 +0300 Subject: [PATCH 192/547] Add `optimistic=True` option to db_session --- pony/orm/core.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 907b18016..df3c7e5c1 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -365,9 +365,10 @@ def rollback(): select_re = re.compile(r'\s*select\b', re.IGNORECASE) class DBSessionContextManager(object): - __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', 'immediate', 'ddl', 'serializable', 'strict', \ + __slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', \ + 'immediate', 'ddl', 'serializable', 'strict', 
'optimistic', \ 'sql_debug', 'show_values' - def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, + def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, optimistic=True, retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None): if retry is not 0: if type(retry) is not int: throw(TypeError, @@ -383,8 +384,9 @@ def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False db_session.retry = retry db_session.ddl = ddl db_session.serializable = serializable - db_session.immediate = immediate or ddl or serializable + db_session.immediate = immediate or ddl or serializable or not optimistic db_session.strict = strict + db_session.optimistic = optimistic and not serializable db_session.retry_exceptions = retry_exceptions db_session.allowed_exceptions = allowed_exceptions db_session.sql_debug = sql_debug @@ -4889,7 +4891,8 @@ def _save_updated_(obj): val = obj._vals_[attr] values.extend(attr.get_raw_values(val)) cache = obj._session_cache_ - if not cache.db_session.serializable and obj not in cache.for_update: + optimistic_session = cache.db_session is None or cache.db_session.optimistic + if optimistic_session and obj not in cache.for_update: optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) @@ -4926,7 +4929,8 @@ def _save_deleted_(obj): values = [] values.extend(obj._get_raw_pkval_()) cache = obj._session_cache_ - if not cache.db_session.serializable and obj not in cache.for_update: + optimistic_session = cache.db_session is None or cache.db_session.optimistic + if optimistic_session and obj not in cache.for_update: optimistic_ops, optimistic_columns, optimistic_converters, optimistic_values = \ obj._construct_optimistic_criteria_() values.extend(optimistic_values) From e4450d219eb376cb946aeba42cafd8040e91f428 Mon Sep 17 00:00:00 
2001 From: Alexander Tischenko Date: Wed, 13 Sep 2017 19:56:27 +0300 Subject: [PATCH 193/547] Optimistic checking for delete() method. --- pony/orm/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index df3c7e5c1..3495a4bb0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4949,7 +4949,9 @@ def _save_deleted_(obj): obj.__class__._delete_sql_cache_[query_key] = sql, adapter else: sql, adapter = cached_sql arguments = adapter(values) - database._exec_sql(sql, arguments, start_transaction=True) + cursor = database._exec_sql(sql, arguments, start_transaction=True) + if cursor.rowcount != 1: + throw(OptimisticCheckError, 'Object %s was updated outside of current transaction' % safe_repr(obj)) obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) def _save_(obj, dependent_objects=None): From 6ab4da910a189eddfabc082586f451eb0c90d997 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 13 Sep 2017 19:57:44 +0300 Subject: [PATCH 194/547] More correct usage of cursor.rowcount (it may equals to -1) --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 3495a4bb0..95dadde6d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4919,7 +4919,7 @@ def _save_updated_(obj): else: sql, adapter = cached_sql arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) - if cursor.rowcount != 1: + if cursor.rowcount == 0: throw(OptimisticCheckError, 'Object %s was updated outside of current transaction' % safe_repr(obj)) obj._status_ = 'updated' obj._rbits_ |= obj._wbits_ & obj._all_bits_except_volatile_ @@ -4950,7 +4950,7 @@ def _save_deleted_(obj): else: sql, adapter = cached_sql arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) - if cursor.rowcount != 1: + if cursor.rowcount == 0: throw(OptimisticCheckError, 'Object %s was 
updated outside of current transaction' % safe_repr(obj)) obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) From 08c83b9d09e06843320c01100efcb1bf317d8e75 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Sep 2017 16:35:29 +0300 Subject: [PATCH 195/547] Show updated attributes when OptimisticCheckError is being raised --- pony/orm/core.py | 53 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 51 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 95dadde6d..615024ca0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4920,7 +4920,7 @@ def _save_updated_(obj): arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) if cursor.rowcount == 0: - throw(OptimisticCheckError, 'Object %s was updated outside of current transaction' % safe_repr(obj)) + throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'updated' obj._rbits_ |= obj._wbits_ & obj._all_bits_except_volatile_ obj._wbits_ = 0 @@ -4951,9 +4951,58 @@ def _save_deleted_(obj): arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) if cursor.rowcount == 0: - throw(OptimisticCheckError, 'Object %s was updated outside of current transaction' % safe_repr(obj)) + throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) + + def find_updated_attributes(obj): + entity = obj.__class__ + attrs_to_select = [] + attrs_to_select.extend(entity._pk_attrs_) + discr = entity._discriminator_attr_ + if discr is not None and discr.pk_offset is None: + attrs_to_select.append(discr) + for attr in obj._attrs_with_bit_(obj._attrs_with_columns_, obj._rbits_): + optimistic = attr.optimistic if attr.optimistic is not None else attr.converters[0].optimistic + if optimistic: + attrs_to_select.append(attr) + + optimistic_converters = [] + attr_offsets = {} + select_list = [ 'ALL' ] + 
for attr in attrs_to_select: + optimistic_converters.extend(attr.converters) + attr_offsets[attr] = offsets = [] + for columns in attr.columns: + select_list.append([ 'COLUMN', None, columns]) + offsets.append(len(select_list) - 2) + + from_list = [ 'FROM', [ None, 'TABLE', entity._table_ ] ] + pk_columns = entity._pk_columns_ + pk_converters = entity._pk_converters_ + criteria_list = [ [ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ] + for i, (column, converter) in enumerate(izip(pk_columns, pk_converters)) ] + sql_ast = [ 'SELECT', select_list, from_list, [ 'WHERE' ] + criteria_list ] + database = entity._database_ + sql, adapter = database._ast2sql(sql_ast) + arguments = adapter(obj._get_raw_pkval_()) + cursor = database._exec_sql(sql, arguments) + row = cursor.fetchone() + if row is None: + return "Object %s was deleted outside of current transaction" % safe_repr(obj) + + real_entity_subclass, pkval, avdict = entity._parse_row_(row, attr_offsets) + diff = [] + for attr, new_dbval in avdict.items(): + old_dbval = obj._dbvals_[attr] + converter = attr.converters[0] + if old_dbval != new_dbval and ( + attr.reverse or not converter.dbvals_equal(old_dbval, new_dbval)): + diff.append('%s (%r -> %r)' % (attr.name, old_dbval, new_dbval)) + + return "Object %s was updated outside of current transaction%s" % ( + safe_repr(obj), ('. Changes: %s' % ', '.join(diff) if diff else '')) + def _save_(obj, dependent_objects=None): status = obj._status_ if status in ('created', 'modified'): From 53f923263b458b8639436a32b6f7a2b80dac7564 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 14 Sep 2017 18:02:15 +0300 Subject: [PATCH 196/547] StringMixing.negate() method overrides base method for monad. 
--- pony/orm/sqltranslation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 39b941639..61edd27a3 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1415,6 +1415,12 @@ def __getitem__(monad, index): index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] return translator.StringExprMonad(translator, monad.type, sql) + def negate(monad): + sql = monad.getsql()[0] + translator = monad.translator + result = translator.BoolExprMonad(translator, [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) + result.aggregated = monad.aggregated + return result def nonzero(monad): sql = monad.getsql()[0] translator = monad.translator From c1c895ea44936a152bab9c0adaa6c9abbaf71992 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 14 Sep 2017 18:27:14 +0300 Subject: [PATCH 197/547] StringAttrMonad now takes into account `nullable` attribute option in negate() and nonzero() methods. 
--- pony/orm/sqltranslation.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 61edd27a3..9f8d05877 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1623,8 +1623,25 @@ def __init__(monad, parent, attr): parent_subquery = parent_monad.tableref.subquery monad.tableref = parent_subquery.add_tableref(name_path, parent_monad.tableref, attr) +class StringAttrMonad(StringMixin, AttrMonad): + def negate(monad): + sql = monad.getsql()[0] + translator = monad.translator + result_sql = [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ] ] + if monad.attr.nullable: + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + result = translator.BoolExprMonad(translator, result_sql) + result.aggregated = monad.aggregated + return result + def nonzero(monad): + sql = monad.getsql()[0] + translator = monad.translator + result_sql = [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ] ] + result = translator.BoolExprMonad(translator, result_sql) + result.aggregated = monad.aggregated + return result + class NumericAttrMonad(NumericMixin, AttrMonad): pass -class StringAttrMonad(StringMixin, AttrMonad): pass class DateAttrMonad(DateMixin, AttrMonad): pass class TimeAttrMonad(TimeMixin, AttrMonad): pass class TimedeltaAttrMonad(TimedeltaMixin, AttrMonad): pass From a3c44c4ea5363b3a0cd840caf098cc714f6b8185 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 14 Sep 2017 19:22:46 +0300 Subject: [PATCH 198/547] IfExp support for PythonTranslator. 
--- pony/orm/asttranslation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index f2366578c..7018a079a 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -77,6 +77,8 @@ def postGenExprFor(translator, node): return src def postGenExprIf(translator, node): return 'if %s' % node.test.src + def postIfExp(translator, node): + return '%s if %s else %s' % (node.then.src, node.test.src, node.else_.src) @priority(14) def postOr(translator, node): return ' or '.join(expr.src for expr in node.nodes) From b63634b8e3b56619e9bae0494379edcceb729c5f Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 14 Sep 2017 19:52:26 +0300 Subject: [PATCH 199/547] coalesce() function added --- pony/orm/core.py | 6 +++--- pony/orm/sqltranslation.py | 15 ++++++++++++++- pony/utils/utils.py | 6 ++++++ 3 files changed, 23 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 615024ca0..bb369d6b6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -27,7 +27,7 @@ ) from pony import utils from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \ - pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, concat + pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, concat, coalesce __all__ = [ 'pony', @@ -58,7 +58,7 @@ 'count', 'sum', 'min', 'max', 'avg', 'distinct', - 'JOIN', 'desc', 'concat', 'raw_sql', + 'JOIN', 'desc', 'concat', 'coalesce', 'raw_sql', 'buffer', 'unicode', @@ -5813,5 +5813,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr, int} +special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr, int, coalesce} const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta} 
diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 9f8d05877..cfda8398c 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -12,7 +12,7 @@ from pony.thirdparty.compiler import ast from pony import options, utils -from pony.utils import is_ident, throw, reraise, concat, copy_ast +from pony.utils import is_ident, throw, reraise, concat, copy_ast, coalesce from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ @@ -2120,6 +2120,19 @@ class FuncAvgMonad(FuncMonad): def call(monad, x): return x.aggregate('AVG') +class FuncCoalesceMonad(FuncMonad): + func = coalesce + def call(monad, *args): + if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments') + translator = args[0].translator + result_ast = [ 'COALESCE' ] + t = None + for arg in args: + if t is None: t = arg.type + elif arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type') + result_ast.append(arg.getsql()[0]) + return translator.ExprMonad.new(translator, unicode, result_ast) + class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct def call(monad, x): diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 9bf4530aa..7f7fb61fc 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -500,6 +500,12 @@ def avg(iter): if not count: return None return sum / count +def coalesce(*args): + for arg in args: + if arg is not None: + return arg + return None + def distinct(iter): d = defaultdict(int) for item in iter: From ccbcf8cefa618bf0a5a2be02b7dd29df1e203b64 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 14 Sep 2017 20:28:37 +0300 Subject: [PATCH 200/547] Make coalesce() works for objects --- pony/orm/sqltranslation.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git 
a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index cfda8398c..7171c3710 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1719,6 +1719,7 @@ def new(translator, type, sql): elif type is timedelta: cls = translator.TimedeltaExprMonad elif type is datetime: cls = translator.DatetimeExprMonad elif type is Json: cls = translator.JsonExprMonad + elif isinstance(type, EntityMeta): cls = translator.ObjectExprMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(translator, type, sql) def __new__(cls, *args): @@ -1730,6 +1731,10 @@ def __init__(monad, translator, type, sql): def getsql(monad, subquery=None): return [ monad.sql ] +class ObjectExprMonad(ObjectMixin, ExprMonad): + def getsql(monad, subquery=None): + return monad.sql + class StringExprMonad(StringMixin, ExprMonad): pass class NumericExprMonad(NumericMixin, ExprMonad): pass class DateExprMonad(DateMixin, ExprMonad): pass @@ -2125,13 +2130,16 @@ class FuncCoalesceMonad(FuncMonad): def call(monad, *args): if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments') translator = args[0].translator - result_ast = [ 'COALESCE' ] - t = None - for arg in args: - if t is None: t = arg.type - elif arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type') - result_ast.append(arg.getsql()[0]) - return translator.ExprMonad.new(translator, unicode, result_ast) + arg = args[0] + t = arg.type + result = [ [ sql ] for sql in arg.getsql() ] + for arg in args[1:]: + if arg.type is not t: throw(TypeError, 'All arguments of coalesce() function should have the same type') + for i, sql in enumerate(arg.getsql()): + result[i].append(sql) + sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] + if not isinstance(t, EntityMeta): sql = sql[0] + return translator.ExprMonad.new(translator, t, sql) class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct From 
8eaf8128b65ac32d84664f74ae8cf2768d6f0cb0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 18 Sep 2017 18:37:20 +0300 Subject: [PATCH 201/547] Refactoring: extract query._apply_kwargs() function from query.filter() --- pony/orm/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index bb369d6b6..cb3dae01b 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5618,7 +5618,9 @@ def filter(query, *args, **kwargs): entity = query._translator.expr_type if not isinstance(entity, EntityMeta): throw(TypeError, 'Keyword arguments are not allowed: since query result type is not an entity, filter() method can accept only lambda') - + return query._apply_kwargs(kwargs) + def _apply_kwargs(query, kwargs): + entity = query._translator.expr_type get_attr = entity._adict_.get filterattrs = [] value_dict = {} From fb95f005627eb674845653a80df1fda93b44e9d5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Sep 2017 15:58:56 +0300 Subject: [PATCH 202/547] Bug fixed --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index cb3dae01b..e2e7f2d73 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5596,7 +5596,7 @@ def _reapply_filters(query, translator): translator = translator.without_order() elif len(tup) == 1: attrnames = tup[0] - translator.apply_kwfilters(attrnames) + translator = translator.apply_kwfilters(attrnames) elif len(tup) == 2: numbers, args = tup if numbers: translator = translator.order_by_numbers(args) From 6ffa721baa7e5ac247981d911064f50b51d60a30 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Sep 2017 16:01:35 +0300 Subject: [PATCH 203/547] Refactoring: make first tuple item of query._filter() items a method name --- pony/orm/core.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 
e2e7f2d73..5fb1d8328 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5501,9 +5501,9 @@ def order_by(query, *args): if not args: throw(TypeError, 'order_by() method requires at least one argument') if args[0] is None: if len(args) > 1: throw(TypeError, 'When first argument of order_by() method is None, it must be the only argument') - tup = ((),) - new_key = query._key + tup - new_filters = query._filters + tup + tup = ('without_order',) + new_key = query._key + (tup,) + new_filters = query._filters + (tup,) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: new_translator = query._translator.without_order() @@ -5525,8 +5525,10 @@ def order_by(query, *args): else: throw(TypeError, "order_by() method receive an argument of invalid type: %r" % arg) if numbers and attributes: throw(TypeError, 'order_by() method receive invalid combination of arguments') - new_key = query._key + ('order_by', args,) - new_filters = query._filters + ((numbers, args),) + + tup = ('order_by_numbers' if numbers else 'order_by_attributes', args) + new_key = query._key + (tup,) + new_filters = query._filters + (tup,) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: if numbers: new_translator = query._translator.order_by_numbers(args) @@ -5572,7 +5574,7 @@ def _process_lambda(query, func, globals, locals, order_by): else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () new_key = query._key + (('order_by' if order_by else 'filter', pretranslator_key, sorted_vartypes),) - new_filters = query._filters + ((order_by, func_ast, argnames, extractors, vartypes),) + new_filters = query._filters + (('apply_lambda', order_by, func_ast, argnames, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: prev_optimized = prev_translator.optimize @@ -5592,18 +5594,22 @@ def _process_lambda(query, func, globals, locals, order_by): return 
query._clone(_vars=new_query_vars, _key=new_key, _filters=new_filters, _translator=new_translator) def _reapply_filters(query, translator): for i, tup in enumerate(query._filters): - if not tup: + cmd = tup[0] + if cmd == 'without_order': translator = translator.without_order() - elif len(tup) == 1: - attrnames = tup[0] + elif cmd == 'apply_kwfilters': + attrnames = tup[1] translator = translator.apply_kwfilters(attrnames) - elif len(tup) == 2: - numbers, args = tup - if numbers: translator = translator.order_by_numbers(args) - else: translator = translator.order_by_attributes(args) - else: - order_by, func_ast, argnames, extractors, vartypes = tup + elif cmd == 'order_by_numbers': + args = tup[1] + translator = translator.order_by_numbers(args) + elif cmd == 'order_by_attributes': + args = tup[1] + translator = translator.order_by_attributes(args) + elif cmd == 'apply_lambda': + order_by, func_ast, argnames, extractors, vartypes = tup[1:] translator = translator.apply_lambda(i+1, order_by, func_ast, argnames, extractors, vartypes) + else: assert False, cmd return translator @cut_traceback def filter(query, *args, **kwargs): @@ -5639,8 +5645,9 @@ def _apply_kwargs(query, kwargs): value_dict[id] = val filterattrs = tuple(filterattrs) - new_key = query._key + ('filter', filterattrs) - new_filters = query._filters + ((filterattrs,),) + tup = ('apply_kwfilters', filterattrs) + new_key = query._key + (tup,) + new_filters = query._filters + (tup,) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: new_translator = query._translator.apply_kwfilters(filterattrs) From 45246bd47d21c806e9b5a591199ed045b13ca270 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Sep 2017 16:05:49 +0300 Subject: [PATCH 204/547] Minor optimization --- pony/orm/core.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 5fb1d8328..8d33869db 100644 --- a/pony/orm/core.py 
+++ b/pony/orm/core.py @@ -5501,9 +5501,9 @@ def order_by(query, *args): if not args: throw(TypeError, 'order_by() method requires at least one argument') if args[0] is None: if len(args) > 1: throw(TypeError, 'When first argument of order_by() method is None, it must be the only argument') - tup = ('without_order',) - new_key = query._key + (tup,) - new_filters = query._filters + (tup,) + tup = (('without_order',),) + new_key = query._key + tup + new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: new_translator = query._translator.without_order() @@ -5526,9 +5526,9 @@ def order_by(query, *args): if numbers and attributes: throw(TypeError, 'order_by() method receive invalid combination of arguments') - tup = ('order_by_numbers' if numbers else 'order_by_attributes', args) - new_key = query._key + (tup,) - new_filters = query._filters + (tup,) + tup = (('order_by_numbers' if numbers else 'order_by_attributes', args),) + new_key = query._key + tup + new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: if numbers: new_translator = query._translator.order_by_numbers(args) @@ -5645,9 +5645,9 @@ def _apply_kwargs(query, kwargs): value_dict[id] = val filterattrs = tuple(filterattrs) - tup = ('apply_kwfilters', filterattrs) - new_key = query._key + (tup,) - new_filters = query._filters + (tup,) + tup = (('apply_kwfilters', filterattrs),) + new_key = query._key + tup + new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: new_translator = query._translator.apply_kwfilters(filterattrs) From 237d61840240d770e84a7f70eabde31817005cea Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Sep 2017 16:20:19 +0300 Subject: [PATCH 205/547] Refactoring: simplify query._reapply_filters() --- pony/orm/core.py | 23 +++++------------------ 1 file changed, 5 
insertions(+), 18 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 8d33869db..3d9b9cdd6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5574,7 +5574,7 @@ def _process_lambda(query, func, globals, locals, order_by): else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () new_key = query._key + (('order_by' if order_by else 'filter', pretranslator_key, sorted_vartypes),) - new_filters = query._filters + (('apply_lambda', order_by, func_ast, argnames, extractors, vartypes),) + new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: prev_optimized = prev_translator.optimize @@ -5593,23 +5593,10 @@ def _process_lambda(query, func, globals, locals, order_by): query._database._translator_cache[new_key] = new_translator return query._clone(_vars=new_query_vars, _key=new_key, _filters=new_filters, _translator=new_translator) def _reapply_filters(query, translator): - for i, tup in enumerate(query._filters): - cmd = tup[0] - if cmd == 'without_order': - translator = translator.without_order() - elif cmd == 'apply_kwfilters': - attrnames = tup[1] - translator = translator.apply_kwfilters(attrnames) - elif cmd == 'order_by_numbers': - args = tup[1] - translator = translator.order_by_numbers(args) - elif cmd == 'order_by_attributes': - args = tup[1] - translator = translator.order_by_attributes(args) - elif cmd == 'apply_lambda': - order_by, func_ast, argnames, extractors, vartypes = tup[1:] - translator = translator.apply_lambda(i+1, order_by, func_ast, argnames, extractors, vartypes) - else: assert False, cmd + for tup in query._filters: + method_name, args = tup[0], tup[1:] + translator_method = getattr(translator, method_name) + translator = translator_method(*args) return translator @cut_traceback def filter(query, *args, **kwargs): From ad36d36dfd4cae226260b1867dfa7f5e6adddd11 Mon 
Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Sep 2017 15:30:29 +0300 Subject: [PATCH 206/547] query.where() method added --- pony/orm/core.py | 51 ++++++++++++++++++++++++++++---------- pony/orm/sqltranslation.py | 22 ++++++++++------ 2 files changed, 52 insertions(+), 21 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 3d9b9cdd6..536deda51 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5535,7 +5535,7 @@ def order_by(query, *args): else: new_translator = query._translator.order_by_attributes(args) query._database._translator_cache[new_key] = new_translator return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) - def _process_lambda(query, func, globals, locals, order_by): + def _process_lambda(query, func, globals, locals, order_by=False, original_names=False): prev_translator = query._translator argnames = () if isinstance(func, basestring): @@ -5555,11 +5555,16 @@ def _process_lambda(query, func, globals, locals, order_by): else: assert False # pragma: no cover if argnames: - expr_type = prev_translator.expr_type - expr_count = len(expr_type) if type(expr_type) is tuple else 1 - if len(argnames) != expr_count: - throw(TypeError, 'Incorrect number of lambda arguments. ' - 'Expected: %d, got: %d' % (expr_count, len(argnames))) + if original_names: + for name in argnames: + if name not in prev_translator.subquery.tablerefs: throw(TypeError, + 'Lambda argument %s does not correspond to any loop variable in original query' % name) + else: + expr_type = prev_translator.expr_type + expr_count = len(expr_type) if type(expr_type) is tuple else 1 + if len(argnames) != expr_count: + throw(TypeError, 'Incorrect number of lambda arguments. 
' + 'Expected: %d, got: %d' % (expr_count, len(argnames))) filter_num = len(query._filters) + 1 extractors, varnames, func_ast, pretranslator_key = create_extractors( @@ -5574,11 +5579,11 @@ def _process_lambda(query, func, globals, locals, order_by): else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () new_key = query._key + (('order_by' if order_by else 'filter', pretranslator_key, sorted_vartypes),) - new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, extractors, vartypes),) + new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: prev_optimized = prev_translator.optimize - new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, extractors, vartypes) + new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: @@ -5589,7 +5594,7 @@ def _process_lambda(query, func, globals, locals, order_by): new_translator = translator_cls(tree, prev_extractors, prev_vartypes, left_join=True, optimize=name_path) new_translator = query._reapply_filters(new_translator) - new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, extractors, vartypes) + new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) query._database._translator_cache[new_key] = new_translator return query._clone(_vars=new_query_vars, _key=new_key, _filters=new_filters, _translator=new_translator) def _reapply_filters(query, translator): @@ -5612,8 +5617,28 @@ def filter(query, *args, **kwargs): if not isinstance(entity, EntityMeta): throw(TypeError, 'Keyword arguments are not allowed: since query result type is 
not an entity, filter() method can accept only lambda') return query._apply_kwargs(kwargs) - def _apply_kwargs(query, kwargs): - entity = query._translator.expr_type + @cut_traceback + def where(query, *args, **kwargs): + if args: + if isinstance(args[0], RawSQL): + raw = args[0] + return query.where(lambda: raw) + func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=3) + return query._process_lambda(func, globals, locals, order_by=False, original_names=True) + if not kwargs: return query + + if len(query._translator.tree.quals) > 1: throw(TypeError, + 'Keyword arguments are not allowed: query iterates over more than one entity') + return query._apply_kwargs(kwargs, original_names=True) + def _apply_kwargs(query, kwargs, original_names=False): + translator = query._translator + if original_names: + tablerefs = translator.subquery.tablerefs + alias = translator.tree.quals[0].assign.name + tableref = tablerefs[alias] + entity = tableref.entity + else: + entity = translator.expr_type get_attr = entity._adict_.get filterattrs = [] value_dict = {} @@ -5632,12 +5657,12 @@ def _apply_kwargs(query, kwargs): value_dict[id] = val filterattrs = tuple(filterattrs) - tup = (('apply_kwfilters', filterattrs),) + tup = (('apply_kwfilters', filterattrs, original_names),) new_key = query._key + tup new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: - new_translator = query._translator.apply_kwfilters(filterattrs) + new_translator = translator.apply_kwfilters(filterattrs, original_names) query._database._translator_cache[new_key] = new_translator new_query = query._clone(_key=new_key, _filters=new_filters, _translator=new_translator, _next_kwarg_id=next_id, _vars=query._vars.copy()) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 7171c3710..88b692e1d 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -218,6 +218,7 @@ def 
__init__(translator, tree, extractors, vartypes, parent_translator=None, lef tableref = TableRef(subquery, name, entity) tablerefs[name] = tableref tableref.make_join() + node.monad = translator.ObjectIterMonad(translator, tableref, entity) else: attr_names = [] while isinstance(node, ast.Getattr): @@ -578,29 +579,34 @@ def order_by_attributes(translator, attrs): new_order.append(desc_wrapper([ 'COLUMN', alias, column])) order[:0] = new_order return translator - def apply_kwfilters(translator, filterattrs): - entity = translator.expr_type - if not isinstance(entity, EntityMeta): - throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') + def apply_kwfilters(translator, filterattrs, original_names=False): translator = deepcopy(translator) - expr_monad = translator.tree.expr.monad + if original_names: + object_monad = translator.tree.quals[0].iter.monad + assert isinstance(object_monad.type, EntityMeta) + else: + object_monad = translator.tree.expr.monad + if not isinstance(object_monad.type, EntityMeta): + throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') + monads = [] none_monad = translator.NoneMonad(translator) for attr, id, is_none in filterattrs: - attr_monad = expr_monad.getattr(attr.name) + attr_monad = object_monad.getattr(attr.name) if is_none: monads.append(CmpMonad('is', attr_monad, none_monad)) else: param_monad = translator.ParamMonad.new(translator, attr.py_type, (id, None, None)) monads.append(CmpMonad('==', attr_monad, param_monad)) for m in monads: translator.conditions.extend(m.getsql()) return translator - def apply_lambda(translator, filter_num, order_by, func_ast, argnames, extractors, vartypes): + def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes): translator = deepcopy(translator) func_ast = copy_ast(func_ast) # func_ast = deepcopy(func_ast) translator.filter_num = filter_num 
translator.extractors.update(extractors) translator.vartypes.update(vartypes) translator.argnames = list(argnames) + translator.original_names = original_names translator.dispatch(func_ast) if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes else: nodes = (func_ast,) @@ -671,7 +677,7 @@ def postName(translator, node): t = translator while t is not None: argnames = t.argnames - if argnames is not None and name in argnames: + if argnames is not None and not t.original_names and name in argnames: i = argnames.index(name) return t.expr_monads[i] t = t.parent From f6f64269f9d385ca119eb92bfc09a9cc3590adcb Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 20 Sep 2017 17:55:16 +0300 Subject: [PATCH 207/547] query._key value fixed --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 536deda51..394fcf51f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5578,7 +5578,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names sorted_vartypes = tuple(vartypes[filter_num, name] for name in varnames) else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () - new_key = query._key + (('order_by' if order_by else 'filter', pretranslator_key, sorted_vartypes),) + new_key = query._key + (('order_by' if order_by else 'where' if original_names else 'filter', pretranslator_key, sorted_vartypes),) new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: From cccb2a94d325a30eabec1628658e14bae91c2cb0 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 20 Sep 2017 18:50:05 +0300 Subject: [PATCH 208/547] Fixes #278: Cascade delete error: FOREIGN KEY constraint failed, with complex entity relationships --- pony/orm/core.py | 27 +++++++++++++++------------ 1 file changed, 15 
insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 394fcf51f..4e8ba665f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4582,9 +4582,22 @@ def undo_func(): undo_funcs.append(undo_func) try: for attr in obj._attrs_: - reverse = attr.reverse - if not reverse: continue + if not attr.is_collection: continue + if isinstance(attr, Set): + set_wrapper = attr.__get__(obj) + if not set_wrapper.__nonzero__(): pass + elif attr.cascade_delete: + for robj in set_wrapper: robj._delete_(undo_funcs) + elif not attr.reverse.is_required: attr.__set__(obj, (), undo_funcs) + else: throw(ConstraintError, "Cannot delete object %s, because it has non-empty set of %s, " + "and 'cascade_delete' option of %s is not set" + % (obj, attr.name, attr)) + else: throw(NotImplementedError) + + for attr in obj._attrs_: if not attr.is_collection: + reverse = attr.reverse + if not reverse: continue if not reverse.is_collection: val = get_val(attr) if attr in obj._vals_ else attr.load(obj) if val is None: continue @@ -4599,16 +4612,6 @@ def undo_func(): if val is None: continue reverse.reverse_remove((val,), obj, undo_funcs) else: throw(NotImplementedError) - elif isinstance(attr, Set): - set_wrapper = attr.__get__(obj) - if not set_wrapper.__nonzero__(): pass - elif attr.cascade_delete: - for robj in set_wrapper: robj._delete_(undo_funcs) - elif not reverse.is_required: attr.__set__(obj, (), undo_funcs) - else: throw(ConstraintError, "Cannot delete object %s, because it has non-empty set of %s, " - "and 'cascade_delete' option of %s is not set" - % (obj, attr.name, attr)) - else: throw(NotImplementedError) cache_indexes = cache.indexes for attr in obj._simple_keys_: From ed6cf65bced91fb81465c104c1566eef1f7958c4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 23 Sep 2017 16:43:32 +0300 Subject: [PATCH 209/547] Refactoring: table.add_entity() method --- pony/orm/core.py | 10 ++-------- pony/orm/dbschema.py | 17 +++++++++++++---- 2 files changed, 
15 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 4e8ba665f..524b35cba 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -858,14 +858,8 @@ def get_columns(table, column_names): else: assert isinstance(table_name, (basestring, tuple)) table = schema.tables.get(table_name) - if table is None: table = schema.add_table(table_name) - elif table.entities: - for e in table.entities: - if e._root_ is not entity._root_: - throw(MappingError, "Entities %s and %s cannot be mapped to table %s " - "because they don't belong to the same hierarchy" - % (e, entity, table_name)) - table.entities.add(entity) + if table is None: table = schema.add_table(table_name, entity) + else: table.add_entity(entity) for attr in entity._new_attrs_: if attr.is_collection: diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index f968d555f..dd1b0f3a3 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -4,7 +4,7 @@ from operator import attrgetter from pony.orm import core -from pony.orm.core import log_sql, DBSchemaError +from pony.orm.core import log_sql, DBSchemaError, MappingError from pony.utils import throw class DBSchema(object): @@ -26,8 +26,8 @@ def case(schema, s): if schema.uppercase: return s.upper().replace('%S', '%s') \ .replace(')S', ')s').replace('%R', '%r').replace(')R', ')r') else: return s.lower() - def add_table(schema, table_name): - return schema.table_class(table_name, schema) + def add_table(schema, table_name, entity=None): + return schema.table_class(table_name, schema, entity) def order_tables_to_create(schema): tables = [] created_tables = set() @@ -85,7 +85,7 @@ def create(table, provider, connection): class Table(DBObject): typename = 'Table' - def __init__(table, name, schema): + def __init__(table, name, schema, entity=None): if name in schema.tables: throw(DBSchemaError, "Table %r already exists in database schema" % name) if name in schema.names: @@ -102,12 +102,21 @@ def __init__(table, name, schema): 
table.parent_tables = set() table.child_tables = set() table.entities = set() + if entity is not None: + table.entities.add(entity) table.m2m = set() def __repr__(table): table_name = table.name if isinstance(table_name, tuple): table_name = '.'.join(table_name) return '' % table_name + def add_entity(table, entity): + for e in table.entities: + if e._root_ is not entity._root_: + throw(MappingError, "Entities %s and %s cannot be mapped to table %s " + "because they don't belong to the same hierarchy" + % (e, entity, table.name)) + table.entities.add(entity) def exists(table, provider, connection, case_sensitive=True): return provider.table_exists(connection, table.name, case_sensitive) def get_create_command(table): From b2a68fc05a823abbd93e603f9c2618f723939435 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 23 Sep 2017 18:21:49 +0300 Subject: [PATCH 210/547] Fixes #295: Storage engine can not be set per table --- pony/orm/core.py | 8 ++++++++ pony/orm/dbschema.py | 12 ++++++++++++ 2 files changed, 20 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 524b35cba..44d878d36 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -838,6 +838,8 @@ def generate_mapping(database, filename=None, check_tables=True, create_tables=F entity._resolve_attr_types_() for entity in entities: entity._link_reverse_attrs_() + for entity in entities: + entity._check_table_options_() def get_columns(table, column_names): column_dict = table.column_dict @@ -3651,6 +3653,12 @@ def _link_reverse_attrs_(entity): attr2.reverse = attr attr.linked() attr2.linked() + def _check_table_options_(entity): + if entity._root_ is not entity: + if '_table_options_' in entity.__dict__: throw(TypeError, + 'Cannot redefine %s options in %s entity' % (entity._root_.__name__, entity.__name__)) + elif not hasattr(entity, '_table_options_'): + entity._table_options_ = {} def _get_pk_columns_(entity): if entity._pk_columns_ is not None: return entity._pk_columns_ pk_columns = 
[] diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index dd1b0f3a3..334fd3566 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -102,8 +102,10 @@ def __init__(table, name, schema, entity=None): table.parent_tables = set() table.child_tables = set() table.entities = set() + table.options = {} if entity is not None: table.entities.add(entity) + table.options = entity._table_options_ table.m2m = set() def __repr__(table): table_name = table.name @@ -116,6 +118,7 @@ def add_entity(table, entity): throw(MappingError, "Entities %s and %s cannot be mapped to table %s " "because they don't belong to the same hierarchy" % (e, entity, table.name)) + assert '_table_options_' not in entity.__dict__ table.entities.add(entity) def exists(table, provider, connection, case_sensitive=True): return provider.table_exists(connection, table.name, case_sensitive) @@ -143,7 +146,16 @@ def get_create_command(table): cmd.append(schema.indent+foreign_key.get_sql() + ',') cmd[-1] = cmd[-1][:-1] cmd.append(')') + for name, value in sorted(table.options.items()): + option = table.format_option(name, value) + if option: cmd.append(option) return '\n'.join(cmd) + def format_option(table, name, value): + if value is True: + return name + if value is False: + return None + return '%s %s' % (name, value) def get_objects_to_create(table, created_tables=None): if created_tables is None: created_tables = set() created_tables.add(table) From e299c785650fcf1e5c3baefb5ff3631070732fce Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 23 Sep 2017 18:29:28 +0300 Subject: [PATCH 211/547] Fixes #294: Real stack traces swallowed within IPython shell --- pony/utils/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 7f7fb61fc..543ba133e 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -94,7 +94,8 @@ def cut_traceback(func, *args, **kwargs): last_pony_tb = tb tb = tb.tb_next if last_pony_tb is 
None: raise - if tb.tb_frame.f_globals.get('__name__').startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': + module_name = tb.tb_frame.f_globals.get('__name__') or '' + if module_name.startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': reraise(exc_type, exc, last_pony_tb) reraise(exc_type, exc, full_tb) finally: From 0547ea9b3aa56aacc50d7708aec7d7401016bd29 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Sep 2017 15:37:27 +0300 Subject: [PATCH 212/547] between(x, a, b) function added --- pony/orm/core.py | 7 ++++--- pony/orm/sqltranslation.py | 13 ++++++++++++- pony/utils/utils.py | 3 +++ 3 files changed, 19 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 44d878d36..93db8ac51 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -27,7 +27,8 @@ ) from pony import utils from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \ - pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, concat, coalesce + pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ + between, concat, coalesce __all__ = [ 'pony', @@ -58,7 +59,7 @@ 'count', 'sum', 'min', 'max', 'avg', 'distinct', - 'JOIN', 'desc', 'concat', 'coalesce', 'raw_sql', + 'JOIN', 'desc', 'between', 'concat', 'coalesce', 'raw_sql', 'buffer', 'unicode', @@ -5839,5 +5840,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr, int, coalesce} +special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr, int, between, coalesce} const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta} diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 88b692e1d..cdf15fa73 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ 
-12,7 +12,7 @@ from pony.thirdparty.compiler import ast from pony import options, utils -from pony.utils import is_ident, throw, reraise, concat, copy_ast, coalesce +from pony.utils import is_ident, throw, reraise, copy_ast, between, concat, coalesce from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ @@ -2082,6 +2082,17 @@ def call_now(monad): translator = monad.translator return translator.DatetimeExprMonad(translator, datetime, [ 'NOW' ]) +class FuncBetweenMonad(FuncMonad): + func = between + def call(monad, x, a, b): + check_comparable(x, a, '<') + check_comparable(x, b, '<') + if isinstance(x.type, EntityMeta): throw(TypeError, + '%s instance cannot be argument of between() function: {EXPR}' % x.type.__name__) + translator = x.translator + sql = [ 'BETWEEN', x.getsql()[0], a.getsql()[0], b.getsql()[0] ] + return translator.BoolExprMonad(translator, sql) + class FuncConcatMonad(FuncMonad): func = concat def call(monad, *args): diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 543ba133e..f717c98ee 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -516,6 +516,9 @@ def distinct(iter): def concat(*args): return ''.join(tostring(arg) for arg in args) +def between(a, x, y): + return a <= x <= y + def is_utf8(encoding): return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') From e1fb8450425c2c5540c38236abd7e1c70c4c4909 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 4 Oct 2017 15:37:25 +0300 Subject: [PATCH 213/547] Remove int, between and coalesce from special_functions --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 93db8ac51..ff4a18ecf 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5840,5 +5840,5 @@ def show(entity): from pprint import pprint pprint(x) -special_functions = 
{itertools.count, utils.count, count, random, raw_sql, getattr, int, between, coalesce} +special_functions = {itertools.count, utils.count, count, random, raw_sql, getattr} const_functions = {buffer, Decimal, datetime.datetime, datetime.date, datetime.time, datetime.timedelta} From fcecf7a28f5c0483334a54ff31c8b25553de8bdc Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Mon, 16 Oct 2017 19:38:35 +0300 Subject: [PATCH 214/547] Bug fixed: Entity.get() sql should not have LIMIT 2 while searching by unique composite key --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index ff4a18ecf..1bbc3c78d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3819,9 +3819,9 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False): get_val = avdict.get vals = tuple(get_val(attr) for attr in attrs) if None in vals: continue + unique = True cache_index = cache_indexes.get(attrs) if cache_index is None: continue - unique = True obj = cache_index.get(vals) if obj is not None: break if obj is None: From 132c83a574fec5b7e72be51e343a19409e6ff976 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 23 Oct 2017 19:41:10 +0300 Subject: [PATCH 215/547] Add `sort_by` method as an alias to `order_by`. Later `order_by` semantics will be changed. 
You need to rename `order_by` calls to `sort_by` in your project --- pony/orm/core.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 1bbc3c78d..39360c2f0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3307,6 +3307,8 @@ def page(wrapper, pagenum, pagesize=10): return wrapper.select().page(pagenum, pagesize) def order_by(wrapper, *args): return wrapper.select().order_by(*args) + def sort_by(wrapper, *args): + return wrapper.select().sort_by(*args) def random(wrapper, limit): return wrapper.select().random(limit) @@ -5504,9 +5506,14 @@ def __iter__(query): return iter(query._fetch()) @cut_traceback def order_by(query, *args): - if not args: throw(TypeError, 'order_by() method requires at least one argument') + return query._order_by('order_by', *args) + @cut_traceback + def sort_by(query, *args): + return query._order_by('sort_by', *args) + def _order_by(query, method_name, *args): + if not args: throw(TypeError, '%s() method requires at least one argument' % method_name) if args[0] is None: - if len(args) > 1: throw(TypeError, 'When first argument of order_by() method is None, it must be the only argument') + if len(args) > 1: throw(TypeError, 'When first argument of %s() method is None, it must be the only argument' % method_name) tup = (('without_order',),) new_key = query._key + tup new_filters = query._filters + tup @@ -5517,7 +5524,7 @@ def order_by(query, *args): return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) if isinstance(args[0], (basestring, types.FunctionType)): - func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=4) return query._process_lambda(func, globals, locals, order_by=True) if isinstance(args[0], RawSQL): From 0f809f18bc3329564808466c3a45acaccb508fa2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 23 Oct 
2017 22:11:30 +0300 Subject: [PATCH 216/547] Update changelog and change pony version: 0.7.3-dev -> 0.7.3 --- CHANGELOG.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index bc3375323..313c2aacf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,49 @@ +# Pony ORM Release 0.7.3 (2017-10-23) + +## New features + +* `where()` method added to query +* `coalesce()` function added +* `between(x, a, b)` function added +* #295: Add `_table_options_` for entity class to specify engine, tablespace, etc. +* Make debug flag thread-local +* `sql_debugging` context manager added +* `sql_debug` and show_values arguments to db_session added +* `set_sql_debug` function added as alias to (to be deprecated) `sql_debug` function +* Allow `db_session` to accept `ddl` parameter when used as context manager +* Add `optimistic=True` option to db_session +* Skip optimistic checks for queries in `db_session` with `serializable=True` +* `fk_name` option added for attributes in order to specify foreign key name +* #280: Now it's possible to specify `timeout` option, as well as pass other keyword arguments for `sqlite3.connect` function +* Add support of explicit casting to int in queries using `int()` function +* Added modulo division % native support in queries + +## Bugfixes + +* Fix bugs with composite table names +* Fix invalid foreign key & index names for tables which names include schema name +* For queries like `select(x for x in MyObject if not x.description)` add "OR x.info IS NULL" for nullable string columns +* Add optimistic checking for `delete()` method +* Show updated attributes when `OptimisticCheckError` is being raised +* Fix incorrect aliases in nested queries +* Correctly pass exception from user-defined functions in SQLite +* More clear error messages for `UnrepeatableReadError` +* Fix `db_session(strict=True)` which was broken in 2d3afb24 +* 
Fixes #170: Problem with a primary key column used as a part of another key +* Fixes #223: incorrect result of `getattr(entity, attrname)` when the same lambda applies to different entities +* Fixes #266: Add handler to `"pony.orm"` logger does not work +* Fixes #278: Cascade delete error: FOREIGN KEY constraint failed, with complex entity relationships +* Fixes #283: Lost Json update immediately after object creation +* Fixes #284: `query.order_by()` orders Json numbers like strings +* Fixes #288: Expression text parsing issue in Python 3 +* Fixes #293: translation of if-expressions in expression +* Fixes #294: Real stack traces swallowed within IPython shell +* `Collection.count()` method should check if session is alive +* Set `obj._session_cache_` to None after exiting from db session for better garbage collection +* Unload collections which are not fully loaded after exiting from db session for better garbage collection +* Raise on unknown options for attributes that are part of relationship + + # Pony ORM Release 0.7.2 (2017-07-17) ## New features diff --git a/pony/__init__.py b/pony/__init__.py index 99ca019c1..ed8f07475 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.3-dev' +__version__ = '0.7.3' uid = str(random.randint(1, 1000000)) From 5dd9f4ef4b02cfa55d38cedd051a512217cad4d8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 28 Oct 2017 14:51:14 +0300 Subject: [PATCH 217/547] Change Pony version: 0.7.3 -> 0.7.4-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index ed8f07475..2f08649ef 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.3' +__version__ = '0.7.4-dev' uid = str(random.randint(1, 1000000)) From 37195e112b7152065fac8b4a6351d65d54977947 Mon Sep 17 00:00:00 2001 From: 
Alexander Kozlovsky Date: Wed, 1 Nov 2017 00:09:04 +0300 Subject: [PATCH 218/547] __contains__ method should check if objects belong to the same db_session --- pony/orm/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 39360c2f0..19efc0445 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3159,6 +3159,8 @@ def __contains__(wrapper, item): if obj._status_ in del_statuses: throw_object_was_deleted(obj) if obj._vals_ is None: throw_db_session_is_over('read value of', obj, attr) if not isinstance(item, attr.py_type): return False + if item._session_cache_ is not obj._session_cache_: + throw(TransactionError, 'An attempt to mix objects belonging to different transactions') reverse = attr.reverse if not reverse.is_collection: From 6d475ce4c7a5758bca3a53a2ea088c3076eaed79 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 1 Nov 2017 16:14:16 +0300 Subject: [PATCH 219/547] desc() function fixed to allow reverse its effect by calling desc(desc(x)) --- pony/orm/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 19efc0445..2bcaf9dae 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5205,7 +5205,9 @@ def JOIN(expr): def desc(expr): if isinstance(expr, Attribute): return expr.desc - if isinstance(expr, int_types) and expr > 0: + if isinstance(expr, DescWrapper): + return expr.attr + if isinstance(expr, int_types): return -expr if isinstance(expr, basestring): return 'desc(%s)' % expr From 01b1c0e3a47060d0ddb3753e88d024de4de75a37 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 1 Nov 2017 16:07:35 +0300 Subject: [PATCH 220/547] Improved ImportError exception messages when DBAPI provider was not found --- pony/orm/dbproviders/mysql.py | 2 +- pony/orm/dbproviders/postgres.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 
28ab3aa5f..c42469970 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -21,7 +21,7 @@ try: import pymysql as mysql_module except ImportError: - raise ImportError('No module named MySQLdb or pymysql found') + raise ImportError('In order to use PonyORM with MySQL please install MySQLdb or pymysql') from pymysql.converters import escape_str as string_literal import pymysql.converters as mysql_converters from pymysql.constants import FIELD_TYPE, FLAG, CLIENT diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 41412d8d5..0e0b24e75 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -8,8 +8,12 @@ try: import psycopg2 except ImportError: - from psycopg2cffi import compat - compat.register() + try: + from psycopg2cffi import compat + except ImportError: + raise ImportError('In order to use PonyORM with PostgreSQL please install psycopg2 or psycopg2cffi') + else: + compat.register() from psycopg2 import extensions From 8504c099079300293a1258a9a6246e88cdd47c31 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 12 Nov 2017 16:13:19 +0300 Subject: [PATCH 221/547] Fixes #306: support of frozenset constants --- pony/orm/sqltranslation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index cdf15fa73..6679b9d45 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -662,6 +662,8 @@ def preCompare(translator, node): return translator.AndMonad(monads) def postConst(translator, node): value = node.value + if type(value) is frozenset: + value = tuple(sorted(value)) if type(value) is not tuple: return translator.ConstMonad.new(translator, value) else: From a5c813ddaf41d550f5bf486fc2d8fdf0a3f516bb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 12 Nov 2017 16:35:23 +0300 Subject: [PATCH 222/547] Fixes #308: an error when assigning JSON attribute value to the same attribute: 
obj.json_attr = obj.json_attr --- pony/orm/dbapiprovider.py | 2 +- pony/orm/tests/test_json.py | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index ea8a13eb1..6eafad584 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -768,7 +768,7 @@ def default(converter, obj): def validate(converter, val, obj=None): if obj is None or converter.attr is None: return val - if isinstance(val, TrackedValue) and val.obj is obj and val.attr is converter.attr: + if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: return val return TrackedValue.make(obj, converter.attr, val) def val2dbval(converter, val, obj=None): diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 2858a1056..d125b629f 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -278,6 +278,11 @@ def test_dict_set_item(self): p = get(p for p in self.Product) self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) + @db_session + def test_set_same_value(self): + p = get(p for p in self.Product) + p.info = p.info + @db_session def test_len(self): with raises_if(self, self.db.provider.dialect == 'Oracle', From afcef756b75aaae087993a9e8f757eb82d57ec46 Mon Sep 17 00:00:00 2001 From: CW Andrews Date: Sun, 12 Nov 2017 14:44:21 -0500 Subject: [PATCH 223/547] Update README.md Updates to improve English grammar and readability. --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index ff58d1abb..242b02e7f 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,21 @@ Pony Object-Relational Mapper ============================= -Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. 
Pony analyzes the abstract syntax tree of the generator expression and translates it to into SQL query. +Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. Pony analyzes the abstract syntax tree of the generator expression and translates it into a SQL query. -Here is the example of a query in Pony: +Here is an example query in Pony: select(p for p in Product if p.name.startswith('A') and p.cost <= 1000) Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. -By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps saving resources. Pony achieves the easiness of use through the following: +By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps save resources. Pony achieves this ease of use through the following: * Compact entity definitions -* Concise query language -* Ability to work with Pony interactively in Python interpreter -* Comprehensive error messages, showing the exact part where error happened in the query -* Displaying the generated SQL in readable format with indentation +* The concise query language +* Ability to work with Pony interactively in a Python interpreter +* Comprehensive error messages, showing the exact part where an error occurred in the query +* Displaying of the generated SQL in a readable format with indentation All this helps the developer to focus on implementing the business logic of an application, instead of struggling with a mapper trying to understand how to get the data from the database. 
From f56cc504584c10d431459665c2c0751b6caaf38e Mon Sep 17 00:00:00 2001 From: CW Andrews Date: Sun, 12 Nov 2017 14:44:21 -0500 Subject: [PATCH 224/547] Update README.md Updates to improve English grammar and readability. --- README.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index ff58d1abb..242b02e7f 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,21 @@ Pony Object-Relational Mapper ============================= -Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. Pony analyzes the abstract syntax tree of the generator expression and translates it to into SQL query. +Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. Pony analyzes the abstract syntax tree of the generator expression and translates it into a SQL query. -Here is the example of a query in Pony: +Here is an example query in Pony: select(p for p in Product if p.name.startswith('A') and p.cost <= 1000) Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. -By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps saving resources. Pony achieves the easiness of use through the following: +By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps save resources. 
Pony achieves this ease of use through the following: * Compact entity definitions -* Concise query language -* Ability to work with Pony interactively in Python interpreter -* Comprehensive error messages, showing the exact part where error happened in the query -* Displaying the generated SQL in readable format with indentation +* The concise query language +* Ability to work with Pony interactively in a Python interpreter +* Comprehensive error messages, showing the exact part where an error occurred in the query +* Displaying of the generated SQL in a readable format with indentation All this helps the developer to focus on implementing the business logic of an application, instead of struggling with a mapper trying to understand how to get the data from the database. From 2edd28417d14ab02d2f295f64cf65fea679152bb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 13 Nov 2017 11:53:57 +0300 Subject: [PATCH 225/547] Typo fixed --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 242b02e7f..cc7e5f5c7 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ Here is an example query in Pony: Pony translates queries to SQL using a specific database dialect. Currently Pony works with SQLite, MySQL, PostgreSQL and Oracle databases. -By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps save resources. Pony achieves this ease of use through the following: +By providing a Pythonic API, Pony facilitates fast app development. Pony is an easy-to-learn and easy-to-use library. It makes your work more productive and helps to save resources. 
Pony achieves this ease of use through the following: * Compact entity definitions * The concise query language From 6bb72b501e3283d00a1bcfc19ccb75aa40c45f25 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 23 Nov 2017 01:09:56 +0300 Subject: [PATCH 226/547] Fixes #314: AttributeError: 'NoneType' object has no attribute 'seeds' --- pony/orm/core.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2bcaf9dae..ac0d72a17 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2108,8 +2108,9 @@ def get(attr, obj): if vals is None: throw_db_session_is_over('read value of', obj, attr) val = vals[attr] if attr in vals else attr.load(obj) if val is not None and attr.reverse and val._subclasses_ and val._status_ not in ('deleted', 'cancelled'): - seeds = obj._session_cache_.seeds[val._pk_attrs_] - if val in seeds: val._load_() + cache = obj._session_cache_ + if cache is not None and val in cache.seeds[val._pk_attrs_]: + val._load_() return val @cut_traceback def __set__(attr, obj, new_val, undo_funcs=None): From e1dc7feb710df12e844bec7dbb5880e125ca5fb1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 25 Nov 2017 17:35:16 +0300 Subject: [PATCH 227/547] Fix #308: JSON updating bug introduced in commit 1ba41405 --- pony/orm/core.py | 5 ++++- pony/orm/tests/test_json.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index ac0d72a17..8534b1fdb 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4810,7 +4810,8 @@ def _update_dbvals_(obj, after_create, new_dbvals): elif after_create and val is None: obj._rbits_ &= ~bits[attr] else: - dbvals[attr] = new_dbvals.get(attr, val) + if attr in new_dbvals: + dbvals[attr] = new_dbvals[attr] continue # Clear value of volatile attribute or null values after create, because the value may be changed in the DB del vals[attr] @@ -4832,6 +4833,7 @@ def _save_created_(obj): 
new_dbvals[attr] = dbval values.append(dbval) else: + new_dbvals[attr] = val values.extend(attr.get_raw_values(val)) attrs = tuple(attrs) @@ -4896,6 +4898,7 @@ def _save_updated_(obj): new_dbvals[attr] = dbval values.append(dbval) else: + new_dbvals[attr] = val values.extend(attr.get_raw_values(val)) if update_columns: for attr in obj._pk_attrs_: diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index d125b629f..721b61bc5 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -644,3 +644,13 @@ def test_nonzero(self): with db_session: val = select(p.info['id'] for p in Product if not p.info['val']) self.assertEqual(tuple(sorted(val)), (2, 3, 5, 7, 9, 11)) + + @db_session + def test_optimistic_check(self): + p1 = self.Product.select().first() + p1.info['foo'] = 'bar' + flush() + p1.name = 'name2' + flush() + p1.name = 'name3' + flush() From 1008725ce4adc3a83625ff22db9bb66d804a13f2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 26 Nov 2017 13:01:17 +0300 Subject: [PATCH 228/547] Rename some tests --- pony/orm/tests/test_db_session.py | 74 +++++++++++++++---------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 05e538208..dc6f8c197 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -90,19 +90,47 @@ def test(): else: self.fail() + def test_allowed_exceptions_1(self): + # allowed_exceptions may be callable, should commit if nonzero + @db_session(allowed_exceptions=lambda e: isinstance(e, ZeroDivisionError)) + def test(): + self.X(a=3, b=3) + 1/0 + try: + test() + except ZeroDivisionError: + with db_session: + self.assertEqual(count(x for x in self.X), 3) + else: + self.fail() + + def test_allowed_exceptions_2(self): + # allowed_exceptions may be callable, should rollback if not nonzero + @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) + def test(): + 
self.X(a=3, b=3) + 1/0 + try: + test() + except ZeroDivisionError: + with db_session: + self.assertEqual(count(x for x in self.X), 2) + else: + self.fail() + @raises_exception(TypeError, "'retry' parameter of db_session must be of integer type. Got: %r" % str) - def test_db_session_decorator_5(self): + def test_retry_1(self): @db_session(retry='foobar') def test(): pass @raises_exception(TypeError, "'retry' parameter of db_session must not be negative. Got: -1") - def test_db_session_decorator_6(self): + def test_retry_2(self): @db_session(retry=-1) def test(): pass - def test_db_session_decorator_7(self): + def test_retry_3(self): # Should not to do retry until retry count is specified counter = count() @db_session(retry_exceptions=[ZeroDivisionError]) @@ -119,7 +147,7 @@ def test(): else: self.fail() - def test_db_session_decorator_8(self): + def test_retry_4(self): # Should rollback & retry 1 time if retry=1 counter = count() @db_session(retry=1, retry_exceptions=[ZeroDivisionError]) @@ -136,7 +164,7 @@ def test(): else: self.fail() - def test_db_session_decorator_9(self): + def test_retry_5(self): # Should rollback & retry N time if retry=N counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) @@ -153,7 +181,7 @@ def test(): else: self.fail() - def test_db_session_decorator_10(self): + def test_retry_6(self): # Should not retry if the exception not in the list of retry_exceptions counter = count() @db_session(retry=3, retry_exceptions=[TypeError]) @@ -170,7 +198,7 @@ def test(): else: self.fail() - def test_db_session_decorator_11(self): + def test_retry_7(self): # Should commit after successful retrying counter = count() @db_session(retry=5, retry_exceptions=[ZeroDivisionError]) @@ -189,41 +217,13 @@ def test(): @raises_exception(TypeError, "The same exception ZeroDivisionError cannot be specified " "in both allowed and retry exception lists simultaneously") - def test_db_session_decorator_12(self): + def test_retry_8(self): 
@db_session(retry=3, retry_exceptions=[ZeroDivisionError], allowed_exceptions=[ZeroDivisionError]) def test(): pass - def test_db_session_decorator_13(self): - # allowed_exceptions may be callable, should commit if nonzero - @db_session(allowed_exceptions=lambda e: isinstance(e, ZeroDivisionError)) - def test(): - self.X(a=3, b=3) - 1/0 - try: - test() - except ZeroDivisionError: - with db_session: - self.assertEqual(count(x for x in self.X), 3) - else: - self.fail() - - def test_db_session_decorator_14(self): - # allowed_exceptions may be callable, should rollback if not nonzero - @db_session(allowed_exceptions=lambda e: isinstance(e, TypeError)) - def test(): - self.X(a=3, b=3) - 1/0 - try: - test() - except ZeroDivisionError: - with db_session: - self.assertEqual(count(x for x in self.X), 2) - else: - self.fail() - - def test_db_session_decorator_15(self): + def test_retry_9(self): # retry_exceptions may be callable, should retry if nonzero counter = count() @db_session(retry=3, retry_exceptions=lambda e: isinstance(e, ZeroDivisionError)) From de6201c40189f901219b73dfe8521ee49b0e839d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 26 Nov 2017 13:25:02 +0300 Subject: [PATCH 229/547] Fixes #313: missed retry on exception raised during db_session.__exit__ --- pony/orm/core.py | 5 ++++- pony/orm/tests/test_db_session.py | 16 +++++++++++++++- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 8534b1fdb..e2e41263c 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -457,7 +457,10 @@ def new_func(func, *args, **kwargs): for i in xrange(db_session.retry+1): db_session._enter() exc_type = exc = tb = None - try: return func(*args, **kwargs) + try: + result = func(*args, **kwargs) + flush() + return result except: exc_type, exc, tb = sys.exc_info() retry_exceptions = db_session.retry_exceptions diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index dc6f8c197..d9a496c95 
100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -12,7 +12,7 @@ class TestDBSession(unittest.TestCase): def setUp(self): self.db = Database('sqlite', ':memory:') class X(self.db.Entity): - a = Required(int) + a = PrimaryKey(int) b = Optional(int) self.X = X self.db.generate_mapping(create_tables=True) @@ -240,6 +240,20 @@ def test(): else: self.fail() + def test_retry_10(self): + # Issue 313: retry on exception raised during db_session.__exit__ + retries = count() + @db_session(retry=3) + def test(): + next(retries) + self.X(a=1, b=1) + try: + test() + except TransactionIntegrityError: + self.assertEqual(next(retries), 4) + else: + self.fail() + def test_db_session_manager_1(self): with db_session: self.X(a=3, b=3) From 576562b564c46f0c16fc914fa148accd52b56da4 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 26 Nov 2017 14:23:28 +0300 Subject: [PATCH 230/547] Fix: handling incorrect datetime values in mysql --- pony/orm/dbproviders/mysql.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index c42469970..968192317 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -329,4 +329,7 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): def str2datetime(s): if 19 < len(s) < 26: s += '000000'[:26-len(s)] s = s.replace('-', ' ').replace(':', ' ').replace('.', ' ').replace('T', ' ') - return datetime(*imap(int, s.split())) + try: + return datetime(*imap(int, s.split())) + except ValueError: + return None # for incorrect values like 0000-00-00 00:00:00 From a79f8a0648c22dfa634dd0caa3703d4571d85ec2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 26 Nov 2017 14:52:26 +0300 Subject: [PATCH 231/547] Fix retry handling: in PostgreSQL and Oracle an error can be raised during commit --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pony/orm/core.py b/pony/orm/core.py index e2e41263c..8d96c4de6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -459,7 +459,7 @@ def new_func(func, *args, **kwargs): exc_type = exc = tb = None try: result = func(*args, **kwargs) - flush() + commit() return result except: exc_type, exc, tb = sys.exc_info() From 792950f6e2b10d64381dcd76f2616216375bf59e Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 26 Nov 2017 15:33:54 +0300 Subject: [PATCH 232/547] Fix #315: attribute lifting for JSON attributes --- pony/orm/core.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 8d96c4de6..b38709065 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -19,7 +19,7 @@ import pony from pony import options from pony.orm.decompiling import decompile -from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of, Json +from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of, Json, TrackedValue from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, @@ -4167,6 +4167,8 @@ def _get_propagation_mixin_(entity): def fget(wrapper, attr=attr): attrnames = wrapper._attrnames_ + (attr.name,) items = [ x for x in (attr.__get__(item) for item in wrapper) if x is not None ] + if attr.py_type is Json: + return [ item.get_untracked() if isinstance(item, TrackedValue) else item for item in items ] return Multiset(wrapper._obj_, attrnames, items) elif not attr.is_collection: def fget(wrapper, attr=attr): From eeb41addc78bf1e8e1e9817da1bcdc06588a1389 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 14 Nov 2017 18:51:13 +0300 Subject: [PATCH 233/547] JSON optimization --- pony/orm/dbapiprovider.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/dbapiprovider.py 
b/pony/orm/dbapiprovider.py index 6eafad584..610eae237 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -781,6 +781,7 @@ def dbval2val(converter, dbval, obj=None): return val return TrackedValue.make(obj, converter.attr, val) def dbvals_equal(converter, x, y): + if x == y: return True # optimization if isinstance(x, basestring): x = json.loads(x) if isinstance(y, basestring): y = json.loads(y) return x == y From c6191c24301326d9bd404ca9fbf373c247c2e725 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 22 Nov 2017 18:50:28 +0300 Subject: [PATCH 234/547] Fix Entity._construct_optimistic_criteria_() --- pony/orm/core.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b38709065..791959fb7 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4776,10 +4776,7 @@ def _construct_optimistic_criteria_(obj): optimistic_converters.extend(attr.converters) values = attr.get_raw_values(dbval) optimistic_values.extend(values) - if dbval is None: - optimistic_operations.append('IS_NULL') - else: - optimistic_operations.extend(converter.EQ for converter in converters) + optimistic_operations.extend('IS_NULL' if dbval is None else converter.EQ for converter in converters) return optimistic_operations, optimistic_columns, optimistic_converters, optimistic_values def _save_principal_objects_(obj, dependent_objects): if dependent_objects is None: dependent_objects = [] From 4dc41d29e24d546549a1dafd2bd692fa367b49b2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Nov 2017 13:45:39 +0300 Subject: [PATCH 235/547] Don't raise OptimisticCheckError if db_session is not optimistic --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 791959fb7..e51168306 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4935,7 +4935,7 @@ def _save_updated_(obj): else: sql, adapter = cached_sql arguments = 
adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) - if cursor.rowcount == 0: + if cursor.rowcount == 0 and cache.db_session.optimistic: throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'updated' obj._rbits_ |= obj._wbits_ & obj._all_bits_except_volatile_ @@ -4966,7 +4966,7 @@ def _save_deleted_(obj): else: sql, adapter = cached_sql arguments = adapter(values) cursor = database._exec_sql(sql, arguments, start_transaction=True) - if cursor.rowcount == 0: + if cursor.rowcount == 0 and cache.db_session.optimistic: throw(OptimisticCheckError, obj.find_updated_attributes()) obj._status_ = 'deleted' cache.indexes[obj._pk_attrs_].pop(obj._pkval_) From 8402428138f44420d42d622f7a253b105de9fc07 Mon Sep 17 00:00:00 2001 From: Zubeyr Dereli Date: Fri, 31 Mar 2017 17:46:26 +0200 Subject: [PATCH 236/547] Fixes #249: Incorrect mixin used for Timedelta --- pony/orm/sqltranslation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 6679b9d45..eb257c5fd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1708,7 +1708,7 @@ class StringParamMonad(StringMixin, ParamMonad): pass class NumericParamMonad(NumericMixin, ParamMonad): pass class DateParamMonad(DateMixin, ParamMonad): pass class TimeParamMonad(TimeMixin, ParamMonad): pass -class TimedeltaParamMonad(TimeMixin, ParamMonad): pass +class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass From 1795b6c3c9bf7bcc20a1a0c04121611676abce02 Mon Sep 17 00:00:00 2001 From: Zubeyr Dereli Date: Fri, 31 Mar 2017 17:46:26 +0200 Subject: [PATCH 237/547] Fixes #249: Incorrect mixin used for Timedelta --- pony/orm/sqltranslation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py 
b/pony/orm/sqltranslation.py index 6679b9d45..eb257c5fd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1708,7 +1708,7 @@ class StringParamMonad(StringMixin, ParamMonad): pass class NumericParamMonad(NumericMixin, ParamMonad): pass class DateParamMonad(DateMixin, ParamMonad): pass class TimeParamMonad(TimeMixin, ParamMonad): pass -class TimedeltaParamMonad(TimeMixin, ParamMonad): pass +class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass From 53150fa8f57e59959abf6957ab9017a31a3ec726 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 15 Dec 2017 03:03:31 +0300 Subject: [PATCH 238/547] An attempt to fix issue 321: KeyError on delete --- pony/orm/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index e51168306..fad20369c 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1800,6 +1800,7 @@ def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval): # attribute which was created or updated lately clashes with one stored in database cache_index.pop(old_dbval, None) def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): + if prev_vals == new_vals: return if None in prev_vals: prev_vals = None if None in new_vals: new_vals = None if prev_vals is None and new_vals is None: return @@ -1813,6 +1814,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): if prev_vals is not None: del cache_index[prev_vals] undo.append((cache_index, prev_vals, new_vals)) def db_update_composite_index(cache, obj, attrs, prev_vals, new_vals): + if prev_vals == new_vals: return cache_index = cache.indexes[attrs] if None not in new_vals: obj2 = cache_index.setdefault(new_vals, obj) From 4d894f6ef9a92aafe5aded7ebe3faa5cb1c41ff8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 
Dec 2017 12:18:33 +0300 Subject: [PATCH 239/547] Refactoring --- pony/orm/core.py | 32 ++++++++++++++------------------ pony/orm/sqltranslation.py | 8 +++----- 2 files changed, 17 insertions(+), 23 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index fad20369c..2d682655c 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4553,15 +4553,13 @@ def _db_set_(obj, avdict, unpickling=False): cache.db_update_simple_index(obj, attr, old_val, new_dbval) for attrs in obj._composite_keys_: - for attr in attrs: - if attr in avdict: break - else: continue - vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! - prev_vals = tuple(vals) - for i, attr in enumerate(attrs): - if attr in avdict: vals[i] = avdict[attr] - new_vals = tuple(vals) - cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) + if any(attr in avdict for attr in attrs): + vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! + prev_vals = tuple(vals) + for i, attr in enumerate(attrs): + if attr in avdict: vals[i] = avdict[attr] + new_vals = tuple(vals) + cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) for attr, new_val in iteritems(avdict): if not attr.reverse: @@ -4727,15 +4725,13 @@ def undo_func(): old_val = get_val(attr) if old_val != new_val: cache.update_simple_index(obj, attr, old_val, new_val, undo) for attrs in obj._composite_keys_: - for attr in attrs: - if attr in avdict: break - else: continue - vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! - prev_vals = tuple(vals) - for i, attr in enumerate(attrs): - if attr in avdict: vals[i] = avdict[attr] - new_vals = tuple(vals) - cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo) + if any(attr in avdict for attr in attrs): + vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! 
+ prev_vals = tuple(vals) + for i, attr in enumerate(attrs): + if attr in avdict: vals[i] = avdict[attr] + new_vals = tuple(vals) + cache.update_composite_index(obj, attrs, prev_vals, new_vals, undo) for attr, new_val in iteritems(avdict): if not attr.reverse: continue old_val = get_val(attr) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index eb257c5fd..e0ff67075 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -326,11 +326,9 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef for tr in tablerefs.values(): if not tr.can_affect_distinct: continue if tr.name_path in expr_set: continue - for attr in tr.entity._pk_attrs_: - if (tr.name_path, attr) not in expr_set: break - else: continue - translator.distinct = True - break + if any((tr.name_path, attr) not in expr_set for attr in tr.entity._pk_attrs_): + translator.distinct = True + break row_layout = [] offset = 0 provider = translator.database.provider From 5bc3ac15f2fe626cc073689fc177fe6bc49faacf Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Dec 2017 12:22:53 +0300 Subject: [PATCH 240/547] Refactoring --- pony/orm/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 2d682655c..c1dde3990 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4700,9 +4700,7 @@ def set(obj, **kwargs): objects_to_save.append(obj) cache.modified = True if not collection_avdict: - for attr in avdict: - if attr.reverse or attr.is_part_of_unique_index: break - else: + if not any(attr.reverse or attr.is_part_of_unique_index for attr in avdict): obj._vals_.update(avdict) return undo_funcs = [] From 2220047c3e4068aabddef62b6bf70054a8585361 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Dec 2017 12:26:53 +0300 Subject: [PATCH 241/547] Refactoring --- pony/orm/dbschema.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git 
a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 334fd3566..241073c9c 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -350,10 +350,10 @@ def __init__(foreign_key, name, child_table, child_columns, parent_table, parent if index_name is not False: child_columns_len = len(child_columns) - for columns in child_table.indexes: - if columns[:child_columns_len] == child_columns: break - else: child_table.add_index(index_name, child_columns, is_pk=False, - is_unique=False, m2m=bool(child_table.m2m)) + if all(columns[:child_columns_len] != child_columns for columns in child_table.indexes): + child_table.add_index(index_name, child_columns, is_pk=False, + is_unique=False, m2m=bool(child_table.m2m)) + def exists(foreign_key, provider, connection, case_sensitive=True): return provider.fk_exists(connection, foreign_key.child_table.name, foreign_key.name, case_sensitive) def get_sql(foreign_key): From 23aec348d0dd459c956978134e6717aaed5d3d2f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 20 Dec 2017 13:01:14 +0300 Subject: [PATCH 242/547] Fixes #321: KeyError on delete --- pony/orm/core.py | 13 ++++++++++--- pony/orm/tests/test_indexes.py | 16 ++++++++++++++++ 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index c1dde3990..614810dc3 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1800,7 +1800,7 @@ def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval): # attribute which was created or updated lately clashes with one stored in database cache_index.pop(old_dbval, None) def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): - if prev_vals == new_vals: return + assert prev_vals != new_vals if None in prev_vals: prev_vals = None if None in new_vals: new_vals = None if prev_vals is None and new_vals is None: return @@ -1814,7 +1814,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): if prev_vals is not None: del cache_index[prev_vals] 
undo.append((cache_index, prev_vals, new_vals)) def db_update_composite_index(cache, obj, attrs, prev_vals, new_vals): - if prev_vals == new_vals: return + assert prev_vals != new_vals cache_index = cache.indexes[attrs] if None not in new_vals: obj2 = cache_index.setdefault(new_vals, obj) @@ -4688,6 +4688,7 @@ def set(obj, **kwargs): for attr in avdict: if attr not in obj._vals_ and attr.reverse and not attr.reverse.is_collection: attr.load(obj) # loading of one-to-one relations + if wbits is not None: new_wbits = wbits for attr in avdict: new_wbits |= obj._bits_[attr] @@ -4699,10 +4700,16 @@ def set(obj, **kwargs): obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True + if not collection_avdict: if not any(attr.reverse or attr.is_part_of_unique_index for attr in avdict): obj._vals_.update(avdict) return + + for attr, value in items_list(avdict): + if value == get_val(attr): + avdict.pop(attr) + undo_funcs = [] undo = [] def undo_func(): @@ -4721,7 +4728,7 @@ def undo_func(): if attr not in avdict: continue new_val = avdict[attr] old_val = get_val(attr) - if old_val != new_val: cache.update_simple_index(obj, attr, old_val, new_val, undo) + cache.update_simple_index(obj, attr, old_val, new_val, undo) for attrs in obj._composite_keys_: if any(attr in avdict for attr in attrs): vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! 
diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py index eef32618f..f6399a147 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -76,5 +76,21 @@ class User(db.Entity): u = User[1] self.assertEqual(u.name, 'B') + def test_4(self): # issue 321 + db = Database('sqlite', ':memory:') + class Person(db.Entity): + name = Required(str) + age = Required(int) + composite_key(name, age) + + db.generate_mapping(create_tables=True) + with db_session: + p1 = Person(id=1, name='John', age=19) + + with db_session: + p1 = Person[1] + p1.set(name='John', age=19) + p1.delete() + if __name__ == '__main__': unittest.main() From 0b7e14afa5e21d1a358e56c3d46189723678c06f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 12 Jan 2018 23:34:55 +0300 Subject: [PATCH 243/547] Fix Python implementation of between(x, a, b) --- pony/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index f717c98ee..ed5977929 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -517,7 +517,7 @@ def concat(*args): return ''.join(tostring(arg) for arg in args) def between(a, x, y): - return a <= x <= y + return x <= a <= y def is_utf8(encoding): return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') From 3a16a2525bde9cf2f12210ecccd1ee0e85a31d16 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 12 Jan 2018 23:41:08 +0300 Subject: [PATCH 244/547] Rename between() arguments: between(a, x, y) -> between(x, a, b) --- pony/utils/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index ed5977929..c5bdfa312 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -516,8 +516,8 @@ def distinct(iter): def concat(*args): return ''.join(tostring(arg) for arg in args) -def between(a, x, y): - return x <= a <= y +def between(x, a, b): + return a <= x <= b def 
is_utf8(encoding): return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') From d3bcdb16f8bdc9b0490e29f7f813152006ec079f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 10 Feb 2018 00:51:38 +0300 Subject: [PATCH 245/547] Minor refactorings --- pony/orm/core.py | 1 - pony/orm/dbproviders/mysql.py | 1 - pony/orm/dbproviders/oracle.py | 6 ++---- pony/orm/dbproviders/postgres.py | 9 +-------- 4 files changed, 3 insertions(+), 14 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 614810dc3..64bbc1536 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1000,7 +1000,6 @@ def get_columns(table, column_names): @cut_traceback @db_session(ddl=True) def drop_table(database, table_name, if_exists=False, with_all_data=False): - table_name = database._get_table_name(table_name) database._drop_tables([ table_name ], if_exists, with_all_data, try_normalized=True) def _get_table_name(database, table_name): if isinstance(table_name, EntityMeta): diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 968192317..4e2109b63 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -288,7 +288,6 @@ def release(provider, connection, cache=None): raise DBAPIProvider.release(provider, connection, cache) - def table_exists(provider, connection, table_name, case_sensitive=True): db_name, table_name = provider.split_table_name(table_name) cursor = connection.cursor() diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 7ef0a148e..56909572a 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -528,15 +528,13 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): return row[0] if row is not None else None def table_has_data(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - cursor.execute('SELECT 1 FROM %s WHERE ROWNUM = 1' % table_name) 
+ cursor.execute('SELECT 1 FROM %s WHERE ROWNUM = 1' % provider.quote_name(table_name)) return cursor.fetchone() is not None def drop_table(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - sql = 'DROP TABLE %s CASCADE CONSTRAINTS' % table_name + sql = 'DROP TABLE %s CASCADE CONSTRAINTS' % provider.quote_name(table_name) cursor.execute(sql) provider_cls = OraProvider diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 0e0b24e75..7c203439f 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -259,16 +259,9 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): row = cursor.fetchone() return row[0] if row is not None else None - def table_has_data(provider, connection, table_name): - table_name = provider.quote_name(table_name) - cursor = connection.cursor() - cursor.execute('SELECT 1 FROM %s LIMIT 1' % table_name) - return cursor.fetchone() is not None - def drop_table(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - sql = 'DROP TABLE %s CASCADE' % table_name + sql = 'DROP TABLE %s CASCADE' % provider.quote_name(table_name) cursor.execute(sql) converter_classes = [ From 310995c04326a88c737e26bc8a15832496d8ae2c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 10 Feb 2018 00:53:16 +0300 Subject: [PATCH 246/547] Fixes #251: correct dealing with qualified table names --- pony/orm/core.py | 42 ++++++++++++++++++++++------------ pony/orm/dbapiprovider.py | 11 +++++---- pony/orm/dbschema.py | 8 +++---- pony/orm/tests/test_diagram.py | 2 +- 4 files changed, 37 insertions(+), 26 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 64bbc1536..a4defd611 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -896,12 +896,15 @@ def get_columns(table, column_names): if not attr.table: seq_counter = itertools.count(2) while 
m2m_table is not None: - new_table_name = table_name + '_%d' % next(seq_counter) + if isinstance(table_name, basestring): + new_table_name = table_name + '_%d' % next(seq_counter) + else: + schema_name, base_name = provider.split_table_name(table_name) + new_table_name = schema_name, base_name + '_%d' % next(seq_counter) m2m_table = schema.tables.get(new_table_name) table_name = new_table_name - elif m2m_table.entities or m2m_table.m2m: - if isinstance(table_name, tuple): table_name = '.'.join(table_name) - throw(MappingError, "Table name '%s' is already in use" % table_name) + elif m2m_table.entities or m2m_table.m2m: throw(MappingError, + "Table name %s is already in use" % provider.format_table_name(table_name)) else: throw(NotImplementedError) attr.table = reverse.table = table_name m2m_table = schema.add_table(table_name) @@ -1013,9 +1016,13 @@ def _get_table_name(database, table_name): elif table_name is None: if database.schema is None: throw(MappingError, 'No mapping was generated for the database') else: throw(TypeError, 'Table name cannot be None') - elif not isinstance(table_name, basestring): - throw(TypeError, 'Invalid table name: %r' % table_name) - table_name = table_name[:] # table_name = templating.plainstr(table_name) + elif isinstance(table_name, tuple): + for component in table_name: + if not isinstance(component, basestring): + throw(TypeError, 'Invalid table name component: {}'.format(component)) + elif isinstance(table_name, basestring): + table_name = table_name[:] # table_name = templating.plainstr(table_name) + else: throw(TypeError, 'Invalid table name: {}'.format(table_name)) return table_name @cut_traceback @db_session(ddl=True) @@ -1032,19 +1039,24 @@ def _drop_tables(database, table_names, if_exists, with_all_data, try_normalized if provider.table_exists(connection, table_name): existed_tables.append(table_name) elif not if_exists: if try_normalized: - normalized_table_name = provider.normalize_name(table_name) - if 
normalized_table_name != table_name \ - and provider.table_exists(connection, normalized_table_name): - throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)' - % (table_name, normalized_table_name)) - throw(TableDoesNotExist, 'Table %s does not exist' % table_name) + if isinstance(table_name, basestring): + normalized_table_name = provider.normalize_name(table_name) + else: + schema_name, base_name = provider.split_table_name(table_name) + normalized_table_name = schema_name, provider.normalize_name(base_name) + if normalized_table_name != table_name and provider.table_exists(connection, normalized_table_name): + throw(TableDoesNotExist, 'Table %s does not exist (probably you meant table %s)' % ( + provider.format_table_name(table_name), + provider.format_table_name(normalized_table_name))) + throw(TableDoesNotExist, 'Table %s does not exist' % provider.format_table_name(table_name)) if not with_all_data: for table_name in existed_tables: if provider.table_has_data(connection, table_name): throw(TableIsNotEmpty, 'Cannot drop table %s because it is not empty. 
Specify option ' - 'with_all_data=True if you want to drop table with all data' % table_name) + 'with_all_data=True if you want to drop table with all data' + % provider.format_table_name(table_name)) for table_name in existed_tables: - if local.debug: log_orm('DROPPING TABLE %s' % table_name) + if local.debug: log_orm('DROPPING TABLE %s' % provider.format_table_name(table_name)) provider.drop_table(connection, table_name) @cut_traceback @db_session(ddl=True) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 610eae237..9f3540a39 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -153,7 +153,7 @@ def get_default_m2m_column_names(provider, entity): return [ normalize(prefix + column) for column in columns ] def get_default_index_name(provider, table_name, column_names, is_pk=False, is_unique=False, m2m=False): - if is_pk: index_name = 'pk_%s' % table_name + if is_pk: index_name = 'pk_%s' % provider.base_name(table_name) else: if is_unique: template = 'unq_%(tname)s__%(cnames)s' elif m2m: template = 'idx_%(tname)s' @@ -191,6 +191,9 @@ def quote_name(provider, name): return quote_char + name + quote_char return '.'.join(provider.quote_name(item) for item in name) + def format_table_name(provider, name): + return provider.quote_name(name) + def normalize_vars(provider, vars, vartypes): pass @@ -287,9 +290,8 @@ def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): throw(NotImplementedError) def table_has_data(provider, connection, table_name): - table_name = provider.quote_name(table_name) cursor = connection.cursor() - cursor.execute('SELECT 1 FROM %s LIMIT 1' % table_name) + cursor.execute('SELECT 1 FROM %s LIMIT 1' % provider.quote_name(table_name)) return cursor.fetchone() is not None def disable_fk_checks(provider, connection): @@ -299,9 +301,8 @@ def enable_fk_checks(provider, connection, prev_state): pass def drop_table(provider, connection, table_name): - table_name = 
provider.quote_name(table_name) cursor = connection.cursor() - sql = 'DROP TABLE %s' % table_name + sql = 'DROP TABLE %s' % provider.quote_name(table_name) cursor.execute(sql) class Pool(localbase): diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 241073c9c..5e733ed84 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -53,9 +53,10 @@ def create_tables(schema, provider, connection): created_tables = set() for table in schema.order_tables_to_create(): for db_object in table.get_objects_to_create(created_tables): + base_name = provider.base_name(db_object.name) name = db_object.exists(provider, connection, case_sensitive=False) if name is None: db_object.create(provider, connection) - elif name != db_object.name: + elif name != base_name: quote_name = schema.provider.quote_name n1, n2 = quote_name(db_object.name), quote_name(name) tn1, tn2 = db_object.typename, db_object.typename.lower() @@ -108,10 +109,7 @@ def __init__(table, name, schema, entity=None): table.options = entity._table_options_ table.m2m = set() def __repr__(table): - table_name = table.name - if isinstance(table_name, tuple): - table_name = '.'.join(table_name) - return '' % table_name + return '' % table.schema.provider.format_table_name(table.name) def add_entity(table, entity): for e in table.entities: if e._root_ is not entity._root_: diff --git a/pony/orm/tests/test_diagram.py b/pony/orm/tests/test_diagram.py index 73f537f0e..3b525aed5 100644 --- a/pony/orm/tests/test_diagram.py +++ b/pony/orm/tests/test_diagram.py @@ -76,7 +76,7 @@ class Entity2(db.Entity): attr2 = Set(Entity1, table='Table2') db.generate_mapping() - @raises_exception(MappingError, "Table name 'Table1' is already in use") + @raises_exception(MappingError, 'Table name "Table1" is already in use') def test_diagram7(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): From cde9696a18dd4a3255a9a9d35f7d9e7838405fb3 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 28 Feb 2018 
18:54:53 +0300 Subject: [PATCH 247/547] Fixes #331: Overriding __len__ in entity fails --- pony/orm/core.py | 4 +++- pony/orm/tests/test_bug_331.py | 43 ++++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) create mode 100644 pony/orm/tests/test_bug_331.py diff --git a/pony/orm/core.py b/pony/orm/core.py index a4defd611..09eb42336 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4118,7 +4118,9 @@ def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs= if obj is None: with cache.flush_disabled(): - obj = obj_to_init or object.__new__(entity) + obj = obj_to_init + if obj_to_init is None: + obj = object.__new__(entity) cache.objects.add(obj) obj._pkval_ = pkval obj._status_ = status diff --git a/pony/orm/tests/test_bug_331.py b/pony/orm/tests/test_bug_331.py new file mode 100644 index 000000000..00832e484 --- /dev/null +++ b/pony/orm/tests/test_bug_331.py @@ -0,0 +1,43 @@ +import unittest + +from pony import orm + +class Test(unittest.TestCase): + def test_1(self): + db = orm.Database('sqlite', ':memory:') + + class Person(db.Entity): + name = orm.Required(str) + group = orm.Optional(lambda: Group) + + class Group(db.Entity): + title = orm.PrimaryKey(str) + persons = orm.Set(Person) + + def __len__(self): + return len(self.persons) + + db.generate_mapping(create_tables=True) + + with orm.db_session: + p1 = Person(name="Alex") + p2 = Person(name="Brad") + p3 = Person(name="Chad") + p4 = Person(name="Dylan") + p5 = Person(name="Ethan") + + g1 = Group(title="Foxes") + g2 = Group(title="Gorillas") + + g1.persons.add(p1) + g1.persons.add(p2) + g1.persons.add(p3) + g2.persons.add(p4) + g2.persons.add(p5) + orm.commit() + + foxes = Group['Foxes'] + gorillas = Group['Gorillas'] + + self.assertEqual(len(foxes), 3) + self.assertEqual(len(gorillas), 2) From 29088380efa26bc6d9be38194189c46e1ada8a30 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 1 Mar 2018 23:27:45 +0300 Subject: [PATCH 248/547] Fixes 
#325: duplicating percentage sign in raw SQL queries without parameters --- pony/orm/core.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 09eb42336..4e408becd 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -221,6 +221,7 @@ def adapt_sql(sql, paramstyle): result = [] args = [] kwargs = {} + original_sql = sql if paramstyle in ('format', 'pyformat'): sql = sql.replace('%', '%%') while True: try: i = sql.index('$', pos) @@ -256,16 +257,14 @@ def adapt_sql(sql, paramstyle): kwargs[key] = expr result.append('%%(%s)s' % key) else: throw(NotImplementedError) - adapted_sql = ''.join(result) - if args: - source = '(%s,)' % ', '.join(args) - code = compile(source, '', 'eval') - elif kwargs: - source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items()) + if args or kwargs: + adapted_sql = ''.join(result) + if args: source = '(%s,)' % ', '.join(args) + else: source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items()) code = compile(source, '', 'eval') else: + adapted_sql = original_sql code = compile('None', '', 'eval') - if paramstyle in ('format', 'pyformat'): sql = sql.replace('%%', '%') result = adapted_sql, code adapted_sql_cache[(sql, paramstyle)] = result return result From 43cfde0ddaadd75c8985e04176f87e390bd33f27 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 2 Mar 2018 00:52:28 +0300 Subject: [PATCH 249/547] Support of explicit casting to float in queries --- pony/orm/dbproviders/mysql.py | 2 ++ pony/orm/dbproviders/oracle.py | 2 ++ pony/orm/dbproviders/postgres.py | 2 ++ pony/orm/sqlbuilding.py | 2 ++ pony/orm/sqltranslation.py | 7 +++++++ pony/orm/tests/test_json.py | 5 +++++ 6 files changed, 20 insertions(+) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 4e2109b63..11c2410e1 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -62,6 +62,8 @@ def RTRIM(builder, expr, chars=None): 
return 'trim(trailing ', builder(chars), ' from ' ,builder(expr), ')' def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS SIGNED)' + def TO_REAL(builder, expr): + return 'CAST(', builder(expr), ' AS DOUBLE)' def YEAR(builder, expr): return 'year(', builder(expr), ')' def MONTH(builder, expr): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 56909572a..0c6dcf8f3 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -210,6 +210,8 @@ def ROWID(builder, *expr_list): return builder.ALL(*expr_list) def LIMIT(builder, limit, offset=None): assert False # pragma: no cover + def TO_REAL(builder, expr): + return 'CAST(', builder(expr), ' AS NUMBER)' def DATE(builder, expr): return 'TRUNC(', builder(expr), ')' def RANDOM(builder): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 7c203439f..bff51bc73 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -61,6 +61,8 @@ def INSERT(builder, table_name, columns, values, returning=None): return result def TO_INT(builder, expr): return '(', builder(expr), ')::int' + def TO_REAL(builder, expr): + return '(', builder(expr), ')::double precision' def DATE(builder, expr): return '(', builder(expr), ')::date' def RANDOM(builder): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index d449eed7e..f2b233b07 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -506,6 +506,8 @@ def REPLACE(builder, str, from_, to): return 'replace(', builder(str), ', ', builder(from_), ', ', builder(to), ')' def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS integer)' + def TO_REAL(builder, expr): + return 'CAST(', builder(expr), ' AS real)' def TODAY(builder): return 'CURRENT_DATE' def NOW(builder): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e0ff67075..5531197ba 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ 
-1092,6 +1092,8 @@ def abs(monad): throw(TypeError) def cast_from_json(monad, type): assert False, monad def to_int(monad): return NumericExprMonad(monad.translator, int, [ 'TO_INT', monad.getsql()[0] ]) + def to_real(monad): + return NumericExprMonad(monad.translator, float, [ 'TO_REAL', monad.getsql()[0] ]) class RawSQLMonad(Monad): def __init__(monad, translator, rawtype, varkey): @@ -2022,6 +2024,11 @@ class FuncIntMonad(FuncMonad): def call(monad, x): return x.to_int() +class FuncFloatMonad(FuncMonad): + func = float + def call(monad, x): + return x.to_real() + class FuncDecimalMonad(FuncMonad): func = Decimal def call(monad, x): diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 721b61bc5..2ffe1fa77 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -654,3 +654,8 @@ def test_optimistic_check(self): flush() p1.name = 'name3' flush() + + @db_session + def test_avg(self): + result = select(avg(float(p.info['display']['size'])) for p in self.Product)[:] + self.assertEqual(1, 0) From 5183c4463386728020d5f718efee6b05500bc573 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 2 Mar 2018 01:01:39 +0300 Subject: [PATCH 250/547] Fixes #301: Aggregation Operation in JSON Column --- pony/orm/sqltranslation.py | 5 +++-- pony/orm/tests/test_json.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 5531197ba..c7a2df292 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1053,8 +1053,9 @@ def aggregate(monad, func_name): # if isinstance(expr_type, SetType): expr_type = expr_type.item_type if func_name in ('SUM', 'AVG'): if expr_type not in numeric_types: - throw(TypeError, "Function '%s' expects argument of numeric type, got %r in {EXPR}" - % (func_name, type2str(expr_type))) + if expr_type is Json: monad = monad.to_real() + else: throw(TypeError, "Function '%s' expects argument of numeric type, got %r in 
{EXPR}" + % (func_name, type2str(expr_type))) elif func_name in ('MIN', 'MAX'): if expr_type not in comparable_types: throw(TypeError, "Function '%s' cannot be applied to type %r in {EXPR}" diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 2ffe1fa77..d223cf60f 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -657,5 +657,5 @@ def test_optimistic_check(self): @db_session def test_avg(self): - result = select(avg(float(p.info['display']['size'])) for p in self.Product)[:] - self.assertEqual(1, 0) + result = select(avg(p.info['display']['size']) for p in self.Product).first() + self.assertAlmostEqual(result, 9.7) From 977f89a610c6dd8640a3e939cca8f1cb6ff6fe62 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 28 Feb 2018 14:03:56 +0300 Subject: [PATCH 251/547] Explicit casting to bool in queries --- pony/orm/sqltranslation.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index c7a2df292..cd7f67c26 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2020,6 +2020,11 @@ def call(monad, source, encoding=None, errors=None): else: value = buffer(source) return translator.ConstMonad.new(translator, value) +class FuncBoolMonad(FuncMonad): + func = bool + def call(monad, x): + return x.nonzero() + class FuncIntMonad(FuncMonad): func = int def call(monad, x): From d00b311a99b422780161367ab37877850f3bc725 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Nov 2017 15:49:20 +0300 Subject: [PATCH 252/547] Apply @cut_traceback decorator only when pony.MODE is 'INTERACTIVE' --- pony/orm/core.py | 42 +++++++++++++++++----------------- pony/orm/dbproviders/sqlite.py | 5 ++-- pony/utils/utils.py | 9 +++++++- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 4e408becd..21a0b0771 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -26,8 +26,8 @@ OperationalError, 
IntegrityError, InternalError, ProgrammingError, NotSupportedError ) from pony import utils -from pony.utils import localbase, decorator, cut_traceback, throw, reraise, truncate_repr, get_lambda_args, \ - pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ +from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ + get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ between, concat, coalesce __all__ = [ @@ -744,7 +744,7 @@ def rollback(database): except: transact_reraise(RollbackException, [sys.exc_info()]) @cut_traceback def execute(database, sql, globals=None, locals=None): - return database._exec_raw_sql(sql, globals, locals, frame_depth=3, start_transaction=True) + return database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1, start_transaction=True) def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction=False): provider = database.provider if provider is None: throw(MappingError, 'Database object is not bound with a provider yet') @@ -760,7 +760,7 @@ def _exec_raw_sql(database, sql, globals, locals, frame_depth, start_transaction @cut_traceback def select(database, sql, globals=None, locals=None, frame_depth=0): if not select_re.match(sql): sql = 'select ' + sql - cursor = database._exec_raw_sql(sql, globals, locals, frame_depth + 3) + cursor = database._exec_raw_sql(sql, globals, locals, frame_depth+cut_traceback_depth+1) max_fetch_count = options.MAX_FETCH_COUNT if max_fetch_count is not None: result = cursor.fetchmany(max_fetch_count) @@ -776,7 +776,7 @@ def select(database, sql, globals=None, locals=None, frame_depth=0): return [ row_class(row) for row in result ] @cut_traceback def get(database, sql, globals=None, locals=None): - rows = database.select(sql, globals, locals, frame_depth=3) + rows = database.select(sql, globals, locals, 
frame_depth=cut_traceback_depth+1) if not rows: throw(RowNotFound) if len(rows) > 1: throw(MultipleRowsFound) row = rows[0] @@ -784,7 +784,7 @@ def get(database, sql, globals=None, locals=None): @cut_traceback def exists(database, sql, globals=None, locals=None): if not select_re.match(sql): sql = 'select ' + sql - cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=3) + cursor = database._exec_raw_sql(sql, globals, locals, frame_depth=cut_traceback_depth+1) result = cursor.fetchone() return bool(result) @cut_traceback @@ -3315,7 +3315,7 @@ def select(wrapper, *args): s = 'lambda item: JOIN(obj in item.%s)' if reverse.is_collection else 'lambda item: item.%s == obj' query = query.filter(s % reverse.name, {'obj' : obj, 'JOIN': JOIN}) if args: - func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+1) query = query.filter(func, globals, locals) return query filter = select @@ -3709,34 +3709,34 @@ def __getitem__(entity, key): return entity._find_one_(kwargs) @cut_traceback def exists(entity, *args, **kwargs): - if args: return entity._query_from_args_(args, kwargs, frame_depth=3).exists() + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).exists() try: obj = entity._find_one_(kwargs) except ObjectNotFound: return False except MultipleObjectsFoundError: return True return True @cut_traceback def get(entity, *args, **kwargs): - if args: return entity._query_from_args_(args, kwargs, frame_depth=3).get() + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).get() try: return entity._find_one_(kwargs) # can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_for_update(entity, *args, **kwargs): nowait = kwargs.pop('nowait', False) - if args: return entity._query_from_args_(args, kwargs, 
frame_depth=3).for_update(nowait).get() + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).for_update(nowait).get() try: return entity._find_one_(kwargs, True, nowait) # can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_by_sql(entity, sql, globals=None, locals=None): - objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=3) # can throw MultipleObjectsFoundError + objects = entity._find_by_sql_(1, sql, globals, locals, frame_depth=cut_traceback_depth+1) # can throw MultipleObjectsFoundError if not objects: return None assert len(objects) == 1 return objects[0] @cut_traceback def select(entity, *args): - return entity._query_from_args_(args, kwargs=None, frame_depth=3) + return entity._query_from_args_(args, kwargs=None, frame_depth=cut_traceback_depth+1) @cut_traceback def select_by_sql(entity, sql, globals=None, locals=None): - return entity._find_by_sql_(None, sql, globals, locals, frame_depth=3) + return entity._find_by_sql_(None, sql, globals, locals, frame_depth=cut_traceback_depth+1) @cut_traceback def select_random(entity, limit): if entity._pk_is_composite_: return entity.select().random(limit) @@ -5180,23 +5180,23 @@ def make_query(args, frame_depth, left_join=False): @cut_traceback def select(*args): - return make_query(args, frame_depth=3) + return make_query(args, frame_depth=cut_traceback_depth+1) @cut_traceback def left_join(*args): - return make_query(args, frame_depth=3, left_join=True) + return make_query(args, frame_depth=cut_traceback_depth+1, left_join=True) @cut_traceback def get(*args): - return make_query(args, frame_depth=3).get() + return make_query(args, frame_depth=cut_traceback_depth+1).get() @cut_traceback def exists(*args): - return make_query(args, frame_depth=3).exists() + return make_query(args, frame_depth=cut_traceback_depth+1).exists() @cut_traceback def delete(*args): - return make_query(args, frame_depth=3).delete() + return 
make_query(args, frame_depth=cut_traceback_depth+1).delete() def make_aggrfunc(std_func): def aggrfunc(*args, **kwargs): @@ -5549,7 +5549,7 @@ def _order_by(query, method_name, *args): return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator) if isinstance(args[0], (basestring, types.FunctionType)): - func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=4) + func, globals, locals = get_globals_and_locals(args, kwargs=None, frame_depth=cut_traceback_depth+2) return query._process_lambda(func, globals, locals, order_by=True) if isinstance(args[0], RawSQL): @@ -5647,7 +5647,7 @@ def filter(query, *args, **kwargs): if isinstance(args[0], RawSQL): raw = args[0] return query.filter(lambda: raw) - func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1) return query._process_lambda(func, globals, locals, order_by=False) if not kwargs: return query @@ -5661,7 +5661,7 @@ def where(query, *args, **kwargs): if isinstance(args[0], RawSQL): raw = args[0] return query.where(lambda: raw) - func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=3) + func, globals, locals = get_globals_and_locals(args, kwargs, frame_depth=cut_traceback_depth+1) return query._process_lambda(func, globals, locals, order_by=False, original_names=True) if not kwargs: return query diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index d3e5ba343..89814c95c 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -17,7 +17,8 @@ from pony.orm.ormtypes import Json from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions -from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise +from pony.utils import datetime2timestamp, 
timestamp2datetime, absolutize_path, localbase, throw, reraise, \ + cut_traceback_depth class SqliteExtensionUnavailable(Exception): pass @@ -368,7 +369,7 @@ def get_pool(provider, filename, create_db=False, **kwargs): # 2 - pony.dbapiprovider.DBAPIProvider.__init__() # 1 - SQLiteProvider.__init__() # 0 - pony.dbproviders.sqlite.get_pool() - filename = absolutize_path(filename, frame_depth=7) + filename = absolutize_path(filename, frame_depth=cut_traceback_depth+5) return SQLitePool(filename, create_db, **kwargs) def table_exists(provider, connection, table_name, case_sensitive=True): diff --git a/pony/utils/utils.py b/pony/utils/utils.py index c5bdfa312..c858f4699 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -77,7 +77,7 @@ def parameterized_decorator(*args, **kwargs): @decorator def cut_traceback(func, *args, **kwargs): - if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): + if not options.CUT_TRACEBACK: return func(*args, **kwargs) try: return func(*args, **kwargs) @@ -101,6 +101,13 @@ def cut_traceback(func, *args, **kwargs): finally: del exc, full_tb, tb, last_pony_tb +cut_traceback_depth = 2 + +if pony.MODE != 'INTERACTIVE': + cut_traceback_depth = 0 + def cut_traceback(func): + return func + if PY2: exec('''def reraise(exc_type, exc, tb): try: raise exc_type, exc, tb From 84cc51afb8e2874027de775668b79cd0330ca3d2 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Nov 2017 16:52:04 +0300 Subject: [PATCH 253/547] mod_wsgi detection according to official doc --- pony/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/__init__.py b/pony/__init__.py index 2f08649ef..022ca9597 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -16,8 +16,8 @@ def detect_mode(): except ImportError: return 'GAE-SERVER' return 'GAE-LOCAL' - try: mod_wsgi = sys.modules['mod_wsgi'] - except KeyError: pass + try: from mod_wsgi import version + except: pass else: return 'MOD_WSGI' if 'flup.server.fcgi' in 
sys.modules: return 'FCGI-FLUP' From eb15ac95596163612c003c3e4d12c3b68b323c88 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Nov 2017 17:21:06 +0300 Subject: [PATCH 254/547] Fix pony.MODE detection --- pony/__init__.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/pony/__init__.py b/pony/__init__.py index 022ca9597..2702f8965 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -20,24 +20,28 @@ def detect_mode(): except: pass else: return 'MOD_WSGI' - if 'flup.server.fcgi' in sys.modules: return 'FCGI-FLUP' + try: + sys.modules['__main__'].__file__ + except AttributeError: + return 'INTERACTIVE' + if 'flup.server.fcgi' in sys.modules: return 'FCGI-FLUP' if 'uwsgi' in sys.modules: return 'UWSGI' - - try: sys.modules['__main__'].__file__ - except AttributeError: return 'INTERACTIVE' - return 'CHERRYPY' + if 'flask' in sys.modules: return 'FLASK' + if 'cherrypy' in sys.modules: return 'CHERRYPY' + if 'bottle' in sys.modules: return 'BOTTLE' + return 'UNKNOWN' MODE = detect_mode() MAIN_FILE = None -if MODE in ('CHERRYPY', 'GAE-LOCAL', 'GAE-SERVER', 'FCGI-FLUP'): - MAIN_FILE = sys.modules['__main__'].__file__ -elif MODE == 'MOD_WSGI': +if MODE == 'MOD_WSGI': for module_name, module in sys.modules.items(): if module_name.startswith('_mod_wsgi_'): MAIN_FILE = module.__file__ break +elif MODE != 'INTERACTIVE': + MAIN_FILE = sys.modules['__main__'].__file__ if MAIN_FILE is not None: MAIN_DIR = dirname(MAIN_FILE) else: MAIN_DIR = None From 689dd7aded4c66c28f099e0851676327bf574c3a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 18 Oct 2017 15:42:03 +0300 Subject: [PATCH 255/547] Init monad.node as None to satisfy PyCharm --- pony/orm/sqltranslation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index cd7f67c26..a3396976b 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -999,6 +999,7 @@ class 
Monad(with_metaclass(MonadMeta)): disable_distinct = False disable_ordering = False def __init__(monad, translator, type): + monad.node = None monad.translator = translator monad.type = type monad.mixin_init() From 40987be53b9964923b498c627c8830e7526b19c7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 18 Oct 2017 15:27:04 +0300 Subject: [PATCH 256/547] Remove translator argument from MethodMonad --- pony/orm/sqltranslation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a3396976b..a0e90e7ae 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -101,7 +101,7 @@ def dispatch_external(translator, node): if not isinstance(obj, EntityMeta): throw(NotImplementedError) entity_monad = translator.EntityMonad(translator, obj) if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) - monad = translator.MethodMonad(translator, entity_monad, func.__name__) + monad = translator.MethodMonad(entity_monad, func.__name__) elif isinstance(node, ast.Name) and node.name in ('True', 'False'): value = True if node.name == 'True' else False monad = translator.ConstMonad.new(translator, value) @@ -1017,7 +1017,7 @@ def getattr(monad, attrname): if not hasattr(monad, 'call_' + attrname): throw(AttributeError, '%r object has no attribute %r' % (type2str(monad.type), attrname)) translator = monad.translator - return translator.MethodMonad(translator, monad, attrname) + return translator.MethodMonad(monad, attrname) return property_method() def len(monad): throw(TypeError) def count(monad): @@ -1169,8 +1169,8 @@ def raise_forgot_parentheses(monad): throw(TranslationError, 'You seems to forgot parentheses after %s' % ast2src(monad.node)) class MethodMonad(Monad): - def __init__(monad, translator, parent, attrname): - Monad.__init__(monad, translator, 'METHOD') + def __init__(monad, parent, attrname): + Monad.__init__(monad, 
parent.translator, 'METHOD') monad.parent = parent monad.attrname = attrname def getattr(monad, attrname): From b707f7143eddb3fce18cb1d7c3b2fc6021f560fa Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 18 Oct 2017 16:20:53 +0300 Subject: [PATCH 257/547] Renaming: translator.argnames -> translator.lambda_argnames --- pony/orm/sqltranslation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a0e90e7ae..440248b8c 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -164,7 +164,7 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) translator.database = None - translator.argnames = None + translator.lambda_argnames = None translator.filter_num = parent_translator.filter_num if parent_translator is not None else 0 translator.extractors = extractors translator.vartypes = vartypes @@ -603,7 +603,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ translator.filter_num = filter_num translator.extractors.update(extractors) translator.vartypes.update(vartypes) - translator.argnames = list(argnames) + translator.lambda_argnames = list(argnames) translator.original_names = original_names translator.dispatch(func_ast) if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes @@ -676,7 +676,7 @@ def postName(translator, node): name = node.name t = translator while t is not None: - argnames = t.argnames + argnames = t.lambda_argnames if argnames is not None and not t.original_names and name in argnames: i = argnames.index(name) return t.expr_monads[i] From 412afb65274f9bf9d0b2e2d2722b2eb8545eb487 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 26 Nov 2017 16:55:21 +0300 Subject: [PATCH 258/547] Typo fixed --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/pony/orm/core.py b/pony/orm/core.py index 21a0b0771..b6faa77b8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5350,7 +5350,7 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None): elif arguments_type is dict: arguments_key = tuple(sorted(iteritems(arguments))) try: hash(arguments_key) except: query_key = None # arguments are unhashable - else: query_key = sql_key + (arguments_key) + else: query_key = sql_key + (arguments_key,) else: query_key = None return sql, arguments, attr_offsets, query_key def get_sql(query): From 47dbccb3a25dcbfde8af0131d6ec76fde185fef1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 26 Feb 2018 17:31:41 +0300 Subject: [PATCH 259/547] Remove unused code --- pony/orm/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b6faa77b8..b54e14198 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5585,7 +5585,6 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names cells = None elif type(func) is types.FunctionType: argnames = get_lambda_args(func) - subquery = prev_translator.subquery func_id = id(func.func_code if PY2 else func.__code__) func_ast, external_names, cells = decompile(func) elif not order_by: throw(TypeError, From 6793f02e3241c75e203948c095c77b7146cd3d6b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 13 Dec 2017 16:35:10 +0300 Subject: [PATCH 260/547] Improved error message --- pony/orm/sqltranslation.py | 2 +- pony/orm/tests/test_declarative_exceptions.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 440248b8c..9a1793449 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1015,7 +1015,7 @@ def getattr(monad, attrname): try: property_method = getattr(monad, 'attr_' + attrname) except AttributeError: if not hasattr(monad, 'call_' + attrname): - throw(AttributeError, '%r object has no 
attribute %r' % (type2str(monad.type), attrname)) + throw(AttributeError, '%r object has no attribute %r: {EXPR}' % (type2str(monad.type), attrname)) translator = monad.translator return translator.MethodMonad(monad, attrname) return property_method() diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 4f5935cef..01764c405 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -134,7 +134,7 @@ def test24(self): @raises_exception(TypeError, "'chars' argument must be of '%s' type in s.name.strip(1), got: 'int'" % unicode.__name__) def test25(self): select(s.name for s in Student if s.name.strip(1)) - @raises_exception(AttributeError, "'%s' object has no attribute 'unknown'" % unicode.__name__) + @raises_exception(AttributeError, "'%s' object has no attribute 'unknown': s.name.unknown" % unicode.__name__) def test26(self): result = set(select(s for s in Student if s.name.unknown() == "joe")) @raises_exception(AttributeError, "Entity Group does not have attribute foo: s.group.foo") From c882df5937f2cde3d2bcc62bb6ca53f1269de738 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 21 Dec 2017 18:09:46 +0300 Subject: [PATCH 261/547] Functype.__repr__ added --- pony/orm/ormtypes.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index a4bc11630..75af93aa9 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -45,6 +45,8 @@ def __ne__(self, other): return type(other) is not FuncType or self.func != other.func def __hash__(self): return hash(self.func) + 1 + def __repr__(self): + return 'FuncType(%s at %d)' % (self.func.__name__, id(self.func)) class MethodType(object): __slots__ = 'obj', 'func' From c19f58b47cae48d6fb4dd3e62219ba7f57f0293a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 28 Oct 2017 16:05:38 +0300 Subject: [PATCH 262/547] Local variable renaming --- 
pony/orm/core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b54e14198..2ad6ad558 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5299,13 +5299,13 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False translator = database._translator_cache.get(query._key) if translator is None: pickled_tree = pickle_ast(tree) - tree = unpickle_ast(pickled_tree) # tree = deepcopy(tree) + tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls - translator = translator_cls(tree, extractors, vartypes, left_join=left_join) + translator = translator_cls(tree_copy, extractors, vartypes, left_join=left_join) name_path = translator.can_be_optimized() if name_path: - tree = unpickle_ast(pickled_tree) # tree = deepcopy(tree) - try: translator = translator_cls(tree, extractors, vartypes, left_join=True, optimize=name_path) + tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) + try: translator = translator_cls(tree_copy, extractors, vartypes, left_join=True, optimize=name_path) except OptimizationFailed: translator.optimization_failed = True translator.pickled_tree = pickled_tree database._translator_cache[query._key] = translator @@ -5624,11 +5624,11 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: - tree = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) + tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) prev_extractors = prev_translator.extractors prev_vartypes = prev_translator.vartypes translator_cls = prev_translator.__class__ - new_translator = translator_cls(tree, prev_extractors, prev_vartypes, + new_translator = translator_cls(tree_copy, prev_extractors, prev_vartypes, left_join=True, optimize=name_path) new_translator = 
query._reapply_filters(new_translator) new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) From b1c0d76895fa7cde6fb8e11d3066c7980cce0887 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 13 Nov 2017 13:36:50 +0300 Subject: [PATCH 263/547] HashableDict added to pony.utils --- pony/utils/utils.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index c858f4699..e56ee63ee 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -13,7 +13,7 @@ from locale import getpreferredencoding from bisect import bisect from collections import defaultdict -from functools import update_wrapper +from functools import update_wrapper, wraps from xml.etree import cElementTree import pony @@ -553,3 +553,25 @@ def unpickle_ast(pickled): def copy_ast(tree): return unpickle_ast(pickle_ast(tree)) + +def _hashable_wrap(func): + @wraps(func, assigned=('__name__', '__doc__')) + def new_func(self, *args, **kwargs): + if getattr(self, '_hash', None) is not None: + assert False, 'Cannot mutate HashableDict instance after the hash value is calculated' + return func(self, *args, **kwargs) + return new_func + +class HashableDict(dict): + def __hash__(self): + result = getattr(self, '_hash', None) + if result is None: + result = self._hash = hash(tuple(sorted(self.items()))) + return result + __setitem__ = _hashable_wrap(dict.__setitem__) + __delitem__ = _hashable_wrap(dict.__delitem__) + clear = _hashable_wrap(dict.clear) + pop = _hashable_wrap(dict.pop) + popitem = _hashable_wrap(dict.popitem) + setdefault = _hashable_wrap(dict.setdefault) + update = _hashable_wrap(dict.update) From e51e58bf7ef37e9209b4a3404a18d182e8c03ed5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 13 Nov 2017 13:37:47 +0300 Subject: [PATCH 264/547] Use HashableDict for getattr_attrnames in create_extractors() --- pony/orm/asttranslation.py | 
18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 7018a079a..4696a2a37 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -1,11 +1,11 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import basestring +from pony.py23compat import basestring, iteritems from functools import update_wrapper from pony.thirdparty.compiler import ast -from pony.utils import throw, copy_ast +from pony.utils import HashableDict, throw, copy_ast class TranslationError(Exception): pass @@ -307,8 +307,9 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ result = None getattr_extractors = getattr_cache.get(code_key) if getattr_extractors: - getattr_attrname_values = tuple(eval(code, globals, locals) for src, code in getattr_extractors) - extractors_key = (code_key, getattr_attrname_values) + getattr_attrnames = HashableDict({src: eval(code, globals, locals) + for src, code in iteritems(getattr_extractors)}) + extractors_key = (code_key, getattr_attrnames) try: result = extractors_cache.get(extractors_key) except TypeError: @@ -328,24 +329,23 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ extractors[src] = code getattr_extractors = {} - getattr_attrname_dict = {} + getattr_attrnames = HashableDict() for node in pretranslator.getattr_nodes: if node in pretranslator.externals: src = node.src code = extractors[src] getattr_extractors[src] = code attrname_value = eval(code, globals, locals) - getattr_attrname_dict[src] = attrname_value + getattr_attrnames[src] = attrname_value elif isinstance(node, ast.Const): attrname_value = node.value else: throw(TypeError, '`%s` should be either external expression or constant.' % ast2src(node)) if not isinstance(attrname_value, basestring): throw(TypeError, '%s: attribute name must be string. 
Got: %r' % (ast2src(node.parent_node), attrname_value)) node._attrname_value = attrname_value - getattr_cache[code_key] = tuple(sorted(getattr_extractors.items())) + getattr_cache[code_key] = getattr_extractors varnames = list(sorted(extractors)) - getattr_attrname_values = tuple(val for key, val in sorted(getattr_attrname_dict.items())) - extractors_key = (code_key, getattr_attrname_values) + extractors_key = (code_key, getattr_attrnames) result = extractors_cache[extractors_key] = extractors, varnames, tree, extractors_key return result From 819754c1624ec62cd60a2c5ef4acc932ee9ea9a9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 13 Nov 2017 16:44:17 +0300 Subject: [PATCH 265/547] HashableDict.__deepcopy__() method added --- pony/utils/utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index e56ee63ee..3180ae879 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -1,7 +1,7 @@ #coding: cp1251 from __future__ import absolute_import, print_function -from pony.py23compat import PY2, imap, basestring, unicode, pickle +from pony.py23compat import PY2, imap, basestring, unicode, pickle, iteritems import io, re, os, os.path, sys, datetime, inspect, types, linecache, warnings, json @@ -15,6 +15,7 @@ from collections import defaultdict from functools import update_wrapper, wraps from xml.etree import cElementTree +from copy import deepcopy import pony from pony import options @@ -568,6 +569,11 @@ def __hash__(self): if result is None: result = self._hash = hash(tuple(sorted(self.items()))) return result + def __deepcopy__(self, memo): + if getattr(self, '_hash', None) is not None: + return self + return HashableDict({deepcopy(key, memo): deepcopy(value, memo) + for key, value in iteritems(self)}) __setitem__ = _hashable_wrap(dict.__setitem__) __delitem__ = _hashable_wrap(dict.__delitem__) clear = _hashable_wrap(dict.clear) From e3cc7245062f87d8a2c963b74bdadd1d3371e5f1 Mon Sep 17 
00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 13 Nov 2017 16:46:49 +0300 Subject: [PATCH 266/547] Use HashableDict of vartypes in query._key --- pony/orm/asttranslation.py | 3 +-- pony/orm/core.py | 16 ++++++++-------- pony/orm/sqltranslation.py | 2 +- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 4696a2a37..93f172e3f 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -345,7 +345,6 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ node._attrname_value = attrname_value getattr_cache[code_key] = getattr_extractors - varnames = list(sorted(extractors)) extractors_key = (code_key, getattr_attrnames) - result = extractors_cache[extractors_key] = extractors, varnames, tree, extractors_key + result = extractors_cache[extractors_key] = extractors, tree, extractors_key return result diff --git a/pony/orm/core.py b/pony/orm/core.py index 2ad6ad558..e21595a47 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -28,7 +28,7 @@ from pony import utils from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ - between, concat, coalesce + between, concat, coalesce, HashableDict __all__ = [ 'pony', @@ -5245,7 +5245,7 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): for name, cell in cells.items(): locals[name] = cell.cell_contents vars = {} - vartypes = {} + vartypes = HashableDict() for src, code in iteritems(extractors): key = filter_num, src if src == '.0': value = locals['.0'] @@ -5276,7 +5276,7 @@ def unpickle_query(query_result): class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) - extractors, varnames, tree, pretranslator_key = create_extractors( 
+ extractors, tree, pretranslator_key = create_extractors( code_key, tree, globals, locals, special_functions, const_functions) filter_num = 0 vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) @@ -5292,8 +5292,9 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) database.provider.normalize_vars(vars, vartypes) + query._vars = vars - query._key = pretranslator_key, tuple(vartypes[filter_num, name] for name in varnames), left_join + query._key = pretranslator_key, vartypes, left_join query._database = database translator = database._translator_cache.get(query._key) @@ -5604,7 +5605,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names 'Expected: %d, got: %d' % (expr_count, len(argnames))) filter_num = len(query._filters) + 1 - extractors, varnames, func_ast, pretranslator_key = create_extractors( + extractors, func_ast, pretranslator_key = create_extractors( func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) if extractors: @@ -5612,10 +5613,9 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names query._database.provider.normalize_vars(vars, vartypes) new_query_vars = query._vars.copy() new_query_vars.update(vars) - sorted_vartypes = tuple(vartypes[filter_num, name] for name in varnames) - else: new_query_vars, vartypes, sorted_vartypes = query._vars, {}, () + else: new_query_vars, vartypes = query._vars, HashableDict() - new_key = query._key + (('order_by' if order_by else 'where' if original_names else 'filter', pretranslator_key, sorted_vartypes),) + new_key = query._key + (('order_by' if order_by else 'where' if original_names else 'filter', pretranslator_key, vartypes),) 
new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 9a1793449..ba466037b 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -167,7 +167,7 @@ def __init__(translator, tree, extractors, vartypes, parent_translator=None, lef translator.lambda_argnames = None translator.filter_num = parent_translator.filter_num if parent_translator is not None else 0 translator.extractors = extractors - translator.vartypes = vartypes + translator.vartypes = vartypes.copy() translator.parent = parent_translator translator.left_join = left_join translator.optimize = optimize From ee9beb2674088424da136e13edeb44303ef2e6fd Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 26 Nov 2017 16:43:25 +0300 Subject: [PATCH 267/547] Refactoring: rename pretranslator_key -> extractors_key --- pony/orm/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e21595a47..100e16911 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5276,7 +5276,7 @@ def unpickle_query(query_result): class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) - extractors, tree, pretranslator_key = create_extractors( + extractors, tree, extractors_key = create_extractors( code_key, tree, globals, locals, special_functions, const_functions) filter_num = 0 vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) @@ -5294,7 +5294,7 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False database.provider.normalize_vars(vars, vartypes) query._vars = vars - query._key = pretranslator_key, vartypes, left_join + query._key = extractors_key, 
vartypes, left_join query._database = database translator = database._translator_cache.get(query._key) @@ -5605,7 +5605,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names 'Expected: %d, got: %d' % (expr_count, len(argnames))) filter_num = len(query._filters) + 1 - extractors, func_ast, pretranslator_key = create_extractors( + extractors, func_ast, extractors_key = create_extractors( func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) if extractors: @@ -5615,7 +5615,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names new_query_vars.update(vars) else: new_query_vars, vartypes = query._vars, HashableDict() - new_key = query._key + (('order_by' if order_by else 'where' if original_names else 'filter', pretranslator_key, vartypes),) + new_key = query._key + (('order_by' if order_by else 'where' if original_names else 'filter', extractors_key, vartypes),) new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: From 68cc4e86ebad65ba591d3ccef888526c8dcf6531 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 26 Nov 2017 17:03:41 +0300 Subject: [PATCH 268/547] Use HashableDict for arguments_key --- pony/orm/core.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 100e16911..7778c6087 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5346,9 +5346,7 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None): else: sql, adapter, attr_offsets = cache_entry arguments = adapter(query._vars) if query._translator.query_result_is_cacheable: - arguments_type = type(arguments) - if arguments_type is tuple: arguments_key = arguments - elif arguments_type is dict: arguments_key = 
tuple(sorted(iteritems(arguments))) + arguments_key = HashableDict(arguments) if type(arguments) is dict else arguments try: hash(arguments_key) except: query_key = None # arguments are unhashable else: query_key = sql_key + (arguments_key,) From 9837ff95eb4b720214e7a2d48e21b8c146a3bd23 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Nov 2017 14:42:22 +0300 Subject: [PATCH 269/547] Make extractors_key a HashableDict --- pony/orm/asttranslation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 93f172e3f..073cf8045 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -309,7 +309,7 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ if getattr_extractors: getattr_attrnames = HashableDict({src: eval(code, globals, locals) for src, code in iteritems(getattr_extractors)}) - extractors_key = (code_key, getattr_attrnames) + extractors_key = HashableDict(code_key=code_key, getattr_attrnames=getattr_attrnames) try: result = extractors_cache.get(extractors_key) except TypeError: @@ -345,6 +345,6 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ node._attrname_value = attrname_value getattr_cache[code_key] = getattr_extractors - extractors_key = (code_key, getattr_attrnames) + extractors_key = HashableDict(code_key=code_key, getattr_attrnames=getattr_attrnames) result = extractors_cache[extractors_key] = extractors, tree, extractors_key return result From 10131f40207f2b810c1eac38ef112c844cf186b0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 29 Nov 2017 14:43:50 +0300 Subject: [PATCH 270/547] Make query._key a HashableDict --- pony/orm/core.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7778c6087..b6ae0e7a2 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5294,7 +5294,7 @@ def 
__init__(query, code_key, tree, globals, locals, cells=None, left_join=False database.provider.normalize_vars(vars, vartypes) query._vars = vars - query._key = extractors_key, vartypes, left_join + query._key = HashableDict(extractors_key, vartypes=vartypes, left_join=left_join, filters=()) query._database = database translator = database._translator_cache.get(query._key) @@ -5332,8 +5332,8 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None): attrs_to_prefetch = tuple(sorted(query._attrs_to_prefetch_dict.get(expr_type, ()))) else: attrs_to_prefetch = () - sql_key = query._key + (range, query._distinct, aggr_func_name, query._for_update, query._nowait, - options.INNER_JOIN_SYNTAX, attrs_to_prefetch) + sql_key = (query._key, range, query._distinct, aggr_func_name, query._for_update, query._nowait, + options.INNER_JOIN_SYNTAX, attrs_to_prefetch) database = query._database cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: @@ -5508,7 +5508,7 @@ def delete(query, bulk=None): for obj in objects: obj._delete_() return len(objects) translator = query._translator - sql_key = query._key + ('DELETE',) + sql_key = HashableDict(query._key, sql_command='DELETE') database = query._database cache = database._get_cache() cache_entry = database._constructed_sql_cache.get(sql_key) @@ -5539,7 +5539,7 @@ def _order_by(query, method_name, *args): if args[0] is None: if len(args) > 1: throw(TypeError, 'When first argument of %s() method is None, it must be the only argument' % method_name) tup = (('without_order',),) - new_key = query._key + tup + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: @@ -5564,7 +5564,7 @@ def _order_by(query, method_name, *args): throw(TypeError, 'order_by() method receive invalid combination of arguments') tup = (('order_by_numbers' if numbers else 
'order_by_attributes', args),) - new_key = query._key + tup + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: @@ -5612,8 +5612,8 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names new_query_vars = query._vars.copy() new_query_vars.update(vars) else: new_query_vars, vartypes = query._vars, HashableDict() - - new_key = query._key + (('order_by' if order_by else 'where' if original_names else 'filter', extractors_key, vartypes),) + tup = (('order_by' if order_by else 'where' if original_names else 'filter', extractors_key, vartypes),) + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: @@ -5693,7 +5693,7 @@ def _apply_kwargs(query, kwargs, original_names=False): filterattrs = tuple(filterattrs) tup = (('apply_kwfilters', filterattrs, original_names),) - new_key = query._key + tup + new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup new_translator = query._database._translator_cache.get(new_key) if new_translator is None: From 076cfa2baf46e8d528c00bde0e0d6e1e03ca1455 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 14 Mar 2018 20:45:22 +0300 Subject: [PATCH 271/547] Fix unexpected_args() helper function --- pony/orm/dbapiprovider.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 9f3540a39..a81956b7d 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -69,9 +69,9 @@ def wrap_dbapi_exceptions(func, provider, *args, **kwargs): except dbapi_module.Warning as e: raise Warning(e) def 
unexpected_args(attr, args): - throw(TypeError, - 'Unexpected positional argument%s for attribute %s: %r' - % ((args > 1 and 's' or ''), attr, ', '.join(repr(arg) for arg in args))) + throw(TypeError, 'Unexpected positional argument{} for attribute {}: {}'.format( + len(args) > 1 and 's' or '', attr, ', '.join(repr(arg) for arg in args)) + ) version_re = re.compile('[0-9\.]+') From c2891ba5cb5aae3ad1377b7431505da3617d4470 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 5 May 2018 01:30:51 +0300 Subject: [PATCH 272/547] Fixes #325: $$ should be replaced with $ in raw sql queries disregarding the precence of $parameters --- pony/orm/core.py | 2 +- pony/orm/tests/test_crud_raw_sql.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b6ae0e7a2..686056087 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -263,7 +263,7 @@ def adapt_sql(sql, paramstyle): else: source = '{%s}' % ','.join('%r:%s' % item for item in kwargs.items()) code = compile(source, '', 'eval') else: - adapted_sql = original_sql + adapted_sql = original_sql.replace('$$', '$') code = compile('None', '', 'eval') result = adapted_sql, code adapted_sql_cache[(sql, paramstyle)] = result diff --git a/pony/orm/tests/test_crud_raw_sql.py b/pony/orm/tests/test_crud_raw_sql.py index 43af748a5..1ebab6651 100644 --- a/pony/orm/tests/test_crud_raw_sql.py +++ b/pony/orm/tests/test_crud_raw_sql.py @@ -60,5 +60,22 @@ def test3(self): def test4(self): students = Student.select(123) + def test5(self): + x = 1 + y = 30 + cursor = db.execute("select name from Student where id = $x and age = $y") + self.assertEqual(cursor.fetchone()[0], 'A') + + def test6(self): + x = 1 + y = 30 + cursor = db.execute("select name, 'abc$$def%' from Student where id = $x and age = $y") + self.assertEqual(cursor.fetchone(), ('A', 'abc$def%')) + + def test7(self): + cursor = db.execute("select name, 'abc$$def%' from Student where id = 1") + 
self.assertEqual(cursor.fetchone(), ('A', 'abc$def%')) + + if __name__ == '__main__': unittest.main() From d762dc7c8037929b2c0e4beaad539bcdae15770e Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 22 Jun 2018 16:29:48 +0300 Subject: [PATCH 273/547] Decompiler fix for Python 3.6 when operation has EXTENDED_ARGUMENT --- pony/orm/decompiling.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 8a580a15a..1d284e5e1 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -79,15 +79,19 @@ def decompile(decompiler): co_code = code.co_code free = code.co_cellvars + code.co_freevars try: - extended_arg = 0 while decompiler.pos < decompiler.end: i = decompiler.pos if i in decompiler.targets: decompiler.process_target(i) op = ord(code.co_code[i]) if PY36: - if op >= HAVE_ARGUMENT: - oparg = ord(co_code[i + 1]) | extended_arg - extended_arg = (arg << 8) if op == EXTENDED_ARG else 0 + extended_arg = 0 + oparg = ord(code.co_code[i+1]) + while op == EXTENDED_ARG: + extended_arg = (extended_arg | oparg) << 8 + i += 2 + op = ord(code.co_code[i]) + oparg = ord(code.co_code[i+1]) + oparg = None if op < HAVE_ARGUMENT else oparg | extended_arg i += 2 else: i += 1 From f763694212c07ccaf73eafd64c8f6aa901f7ac81 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 26 Jun 2018 15:53:32 +0300 Subject: [PATCH 274/547] Fixes #357: reconnect after PostgreSQL server closed the connection unexpectedly --- pony/orm/dbproviders/postgres.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index bff51bc73..deea7c9bb 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -187,8 +187,7 @@ def inspect_connection(provider, connection): provider.table_if_not_exists_syntax = provider.server_version >= 90100 def should_reconnect(provider, exc): - return isinstance(exc, 
psycopg2.OperationalError) \ - and exc.pgcode is exc.pgerror is exc.cursor is None + return isinstance(exc, psycopg2.OperationalError) and exc.pgcode is None def get_pool(provider, *args, **kwargs): return PGPool(provider.dbapi_module, *args, **kwargs) From 03fba284f3e3b067fc49d5f04f9097eccf88a9d0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 12:36:13 +0300 Subject: [PATCH 275/547] Minor bug fixed --- pony/orm/sqltranslation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index ba466037b..1a4610d40 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1674,7 +1674,7 @@ def new(translator, type, paramkey): elif type is UUID: cls = translator.UuidParamMonad elif type is Json: cls = translator.JsonParamMonad elif isinstance(type, EntityMeta): cls = translator.ObjectParamMonad - else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type)) + else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type,)) result = cls(translator, type, paramkey) result.aggregated = False return result From 1efb96893dae7f28b5cb025a7a398d4c90e3af2d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 12:37:41 +0300 Subject: [PATCH 276/547] Fix transformer.com_generator_expression for Python 3.7 --- pony/thirdparty/compiler/transformer.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index dea639574..993ea4a35 100644 --- a/pony/thirdparty/compiler/transformer.py +++ b/pony/thirdparty/compiler/transformer.py @@ -31,6 +31,7 @@ from .ast import * import parser import symbol +import sys import token # Python 2.6 compatibility fix @@ -1225,6 +1226,9 @@ def com_generator_expression(self, expr, node): # comp_for: 'for' exprlist 'in' test [comp_iter] # comp_if: 'if' test [comp_iter] + if sys.version_info >= (3, 7): + 
node = node[1] # remove async part + lineno = node[1][2] fors = [] while node: From c670422b1454940a30c56f2e80f7634ff9506702 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 13:50:19 +0300 Subject: [PATCH 277/547] Add support of LOAD_METHOD and CALL_METHOD for Python 3.7 --- pony/orm/decompiling.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 1d284e5e1..df11ac89d 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -238,6 +238,15 @@ def CALL_FUNCTION_EX(decompiler, argc): star = decompiler.stack.pop() return decompiler._call_function([], star, star2) + def CALL_METHOD(decompiler, argc): + pop = decompiler.stack.pop + args = [] + for i in range(argc): + args.append(pop()) + args.reverse() + method = pop() + return ast.CallFunc(method, args) + def COMPARE_OP(decompiler, op): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() @@ -339,6 +348,9 @@ def LOAD_GLOBAL(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) + def LOAD_METHOD(decompiler, methname): + return decompiler.LOAD_ATTR(methname) + def LOAD_NAME(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) From d74e77392426cdcc874d3a48dd943b0cb94c7bdb Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 13:55:57 +0300 Subject: [PATCH 278/547] Update setup.py to support Python 3.7 --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 612f77fe6..a089969af 100644 --- a/setup.py +++ b/setup.py @@ -64,6 +64,7 @@ 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Software Development :: Libraries', 'Topic :: Database' ] @@ -89,8 +90,8 @@ if __name__ == "__main__": pv = sys.version_info[:2] - if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6)): - s 
= "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.6." \ + if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7)): + s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.7." \ " You have version %s" print(s % (name, version, sys.version.split(' ', 1)[0])) sys.exit(1) From 47ef2b5740def82ddc2fd0ff1177867d1467c412 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 15:40:35 +0300 Subject: [PATCH 279/547] Add PYPY and PYPY2 flags to py23compat.py --- pony/py23compat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/py23compat.py b/pony/py23compat.py index 51e059da0..f23b99f37 100644 --- a/pony/py23compat.py +++ b/pony/py23compat.py @@ -1,6 +1,8 @@ -import sys +import sys, platform PY2 = sys.version_info[0] == 2 +PYPY = platform.python_implementation() == 'PyPy' +PYPY2 = PYPY and PY2 if PY2: from future_builtins import zip as izip, map as imap From 35d104ad915f7c164997b3f4d60d3ce929ede9ab Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 15:06:40 +0300 Subject: [PATCH 280/547] Fix LOOKUP_METHOD bytecode handling in PyPy --- pony/orm/decompiling.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index df11ac89d..71fda2abb 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -351,6 +351,8 @@ def LOAD_GLOBAL(decompiler, varname): def LOAD_METHOD(decompiler, methname): return decompiler.LOAD_ATTR(methname) + LOOKUP_METHOD = LOAD_METHOD # For PyPy + def LOAD_NAME(decompiler, varname): decompiler.names.add(varname) return ast.Name(varname) From 37f48667d08e66130eceecd9685500087c79190c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 15:06:11 +0300 Subject: [PATCH 281/547] Fix random() function handling in PyPy --- pony/orm/sqltranslation.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py 
b/pony/orm/sqltranslation.py index 1a4610d40..3c042a715 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -98,10 +98,13 @@ def dispatch_external(translator, node): monad = func_monad_class(translator, func) elif tt is MethodType: obj, func = t.obj, t.func - if not isinstance(obj, EntityMeta): throw(NotImplementedError) - entity_monad = translator.EntityMonad(translator, obj) - if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) - monad = translator.MethodMonad(entity_monad, func.__name__) + if isinstance(obj, EntityMeta): + entity_monad = translator.EntityMonad(translator, obj) + if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) + monad = translator.MethodMonad(entity_monad, func.__name__) + elif node.src == 'random': # For PyPy + monad = translator.FuncRandomMonad(translator, t) + else: throw(NotImplementedError) elif isinstance(node, ast.Name) and node.name in ('True', 'False'): value = True if node.name == 'True' else False monad = translator.ConstMonad.new(translator, value) From a2121f87c56e3c6c51feef9f4386a87447e210ae Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 15:38:01 +0300 Subject: [PATCH 282/547] Fix tests for PyPy --- pony/orm/tests/test_declarative_exceptions.py | 12 +++++++++--- pony/orm/tests/test_declarative_func_monad.py | 12 +++++++----- pony/orm/tests/test_query.py | 4 +++- pony/orm/tests/test_raw_sql.py | 5 ++++- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 01764c405..cd04aa263 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import PYPY, PYPY2 import sys, unittest from datetime import date @@ -71,7 +72,8 @@ def test4(self): def test5(self): 
select(s for s in Student if s.name.upper(**{'a':'b', 'c':'d'})) - @raises_exception(ExprEvalError, "1 in 2 raises TypeError: argument of type 'int' is not iterable") + @raises_exception(ExprEvalError, "1 in 2 raises TypeError: argument of type 'int' is not iterable" if not PYPY else + "1 in 2 raises TypeError: 'int' object is not iterable") def test6(self): select(s for s in Student if 1 in 2) @raises_exception(NotImplementedError, 'Group[s.group.number]') @@ -151,7 +153,8 @@ def test29(self): @raises_exception(NotImplementedError, "date(s.id, 1, 1)") def test30(self): select(s for s in Student if s.dob < date(s.id, 1, 1)) - @raises_exception(ExprEvalError, "max() raises TypeError: max expected 1 arguments, got 0") + @raises_exception(ExprEvalError, "max() raises TypeError: max expected 1 arguments, got 0" if not PYPY else + "max() raises TypeError: max() expects at least one argument") def test31(self): select(s for s in Student if s.id < max()) @raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses") @@ -178,7 +181,10 @@ def test38(self): @raises_exception(TypeError, "strip() takes at most 1 argument (3 given)") def test39(self): select(s for s in Student if s.name.strip(1, 2, 3)) - @raises_exception(ExprEvalError, "len(1, 2) == 3 raises TypeError: len() takes exactly one argument (2 given)") + @raises_exception(ExprEvalError, + "len(1, 2) == 3 raises TypeError: len() takes exactly 1 argument (2 given)" if PYPY2 else + "len(1, 2) == 3 raises TypeError: len() takes 1 positional argument but 2 were given" if PYPY else + "len(1, 2) == 3 raises TypeError: len() takes exactly one argument (2 given)") def test40(self): select(s for s in Student if len(1, 2) == 3) @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got 'Student' in sum(s for s in Student if s.group == g)") diff --git a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index 
834e08cc9..c7e87f30f 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2 +from pony.py23compat import PY2, PYPY, PYPY2 import sys, unittest from datetime import date, datetime @@ -113,10 +113,12 @@ def test_datetime_func4(self): def test_datetime_now1(self): result = set(select(s for s in Student if s.dob < date.today())) self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) - @raises_exception(ExprEvalError, "1 < datetime.now() raises TypeError: " + - ("can't compare datetime.datetime to int" if PY2 else - "unorderable types: int() < datetime.datetime()" if sys.version_info < (3, 6) else - "'<' not supported between instances of 'int' and 'datetime.datetime'")) + @raises_exception(ExprEvalError, "1 < datetime.now() raises TypeError: " + ( + "can't compare 'datetime' to 'int'" if PYPY2 else + "unorderable types: int < datetime" if PYPY else + "can't compare datetime.datetime to int" if PY2 else + "unorderable types: int() < datetime.datetime()" if sys.version_info < (3, 6) else + "'<' not supported between instances of 'int' and 'datetime.datetime'")) def test_datetime_now2(self): select(s for s in Student if 1 < datetime.now()) def test_datetime_now3(self): diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 0220e73e6..5a7f2fc3d 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import PYPY2 import unittest from datetime import date @@ -46,7 +47,8 @@ def test2(self): def test3(self): g = Group[1] select(s for s in g.students) - @raises_exception(ExprEvalError, "a raises NameError: name 'a' is not defined") + @raises_exception(ExprEvalError, "a raises NameError: global name 'a' is not defined" if PYPY2 else + 
"a raises NameError: name 'a' is not defined") def test4(self): select(a for s in Student) @raises_exception(TypeError, "Incomparable types '%s' and 'list' in expression: s.name == x" % unicode.__name__) diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index 22f3a89df..f56247b01 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -1,4 +1,5 @@ from __future__ import absolute_import, print_function, division +from pony.py23compat import PYPY2 import unittest from datetime import date @@ -158,7 +159,9 @@ def test_19(self): select(p for p in Person if raw_sql(p.name))[:] @db_session - @raises_exception(ExprEvalError, "raw_sql('p.dob < $x') raises NameError: name 'x' is not defined") + @raises_exception(ExprEvalError, + "raw_sql('p.dob < $x') raises NameError: global name 'x' is not defined" if PYPY2 else + "raw_sql('p.dob < $x') raises NameError: name 'x' is not defined") def test_20(self): # testing for situation where parameter variable is missing select(p for p in Person if raw_sql('p.dob < $x'))[:] From 864d7a3d8b79aa328644e78b8a9284f337945fa6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 30 Jun 2018 15:42:03 +0300 Subject: [PATCH 283/547] Add PyPy classifier to setup.py --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index a089969af..e1851a511 100644 --- a/setup.py +++ b/setup.py @@ -65,6 +65,7 @@ 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: Implementation :: PyPy', 'Topic :: Software Development :: Libraries', 'Topic :: Database' ] From bc46d781ffd677e5f9927e29fd05e4783d40f2d4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 3 Jul 2018 16:25:01 +0300 Subject: [PATCH 284/547] Fix tests --- pony/orm/tests/test_diagram_attribute.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git 
a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index 6505bd255..ec956d85a 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -525,17 +525,21 @@ class Entity2(db.Entity): a = Set('Entity1', py_check=lambda val: True) db.generate_mapping(create_tables=True) - @raises_exception(ValueError, "Check for attribute Entity1.a failed. Value: " + ( - "u'12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345..." if PY2 - else "'123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456..." - )) def test_py_check_truncate(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): a = Required(str, py_check=lambda val: False) db.generate_mapping(create_tables=True) with db_session: - obj = Entity1(a='1234567890' * 1000) + try: + obj = Entity1(a='1234567890' * 1000) + except ValueError as e: + error_message = "Check for attribute Entity1.a failed. Value: " + ( + "u'12345678901234567890123456789012345678901234567890123456789012345678901234567890123456789012345..." if PY2 + else "'123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456...") + self.assertEqual(str(e), error_message) + else: + self.assert_(False) @raises_exception(ValueError, 'Value for attribute Entity1.a is too long. 
Max length is 10, value length is 10000') def test_str_max_len(self): From 401abdc1508d2d8e3540411ee0c32d4f991b09a8 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 3 Jul 2018 16:31:24 +0300 Subject: [PATCH 285/547] auto=True fix --- pony/orm/core.py | 2 ++ pony/orm/tests/test_diagram_attribute.py | 6 ++++++ 2 files changed, 8 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 686056087..0a291d317 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1900,6 +1900,8 @@ def __init__(attr, py_type, *args, **kwargs): attr.entity = attr.name = None attr.args = args attr.auto = kwargs.pop('auto', False) + if attr.auto and (attr.py_type not in int_types): throw(TypeError, + '`auto=True` option can be specified for `int` attributes only, not for `%s`' % (attr.py_type.__name__)) attr.cascade_delete = kwargs.pop('cascade_delete', None) attr.reverse = kwargs.pop('reverse', None) diff --git a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index ec956d85a..9a8997d03 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -233,6 +233,12 @@ class Entity2(db.Entity): d = Optional('Entity1', reverse='a') db.generate_mapping() + @raises_exception(TypeError, '`auto=True` option can be specified for `int` attributes only, not for `str`') + def test_attribute24(self): + db = Database('sqlite', ':memory:') + class Entity1(db.Entity): + a = Required(str, auto=True) + @raises_exception(TypeError, "Parameters 'column' and 'columns' cannot be specified simultaneously") def test_columns1(self): db = Database('sqlite', ':memory:') From 6ddecb0f1c06e8b45899844fa8f5e7fdf05d4677 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 5 Jul 2018 10:17:55 +0300 Subject: [PATCH 286/547] micro refactoring --- pony/orm/sqltranslation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 
3c042a715..0440b9f75 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1067,7 +1067,7 @@ def aggregate(monad, func_name): else: assert False # pragma: no cover expr = monad.getsql() if len(expr) == 1: expr = expr[0] - elif translator.row_value_syntax == True: expr = ['ROW'] + expr + elif translator.row_value_syntax: expr = ['ROW'] + expr else: throw(NotImplementedError, '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. Got: {EXPR} ' From 9a15b3134a4f5c5570e91b0f949342ad9c0c0623 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 4 Jul 2018 18:29:53 +0300 Subject: [PATCH 287/547] Pass `distinct` argument to all aggregate SQLBuilder functions --- pony/orm/core.py | 4 +- pony/orm/dbproviders/sqlite.py | 12 +-- pony/orm/sqlbuilding.py | 40 +++++----- pony/orm/sqltranslation.py | 77 +++++++++---------- .../tests/test_declarative_sqltranslator2.py | 4 +- 5 files changed, 65 insertions(+), 72 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 0a291d317..e248837d8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3138,7 +3138,7 @@ def count(wrapper): where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (i, None, None), converter ] ]) if not reverse.is_collection: table_name = reverse.entity._table_ else: table_name = attr.table - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ None, 'TABLE', table_name ] ], where_list ] sql, adapter = database._ast2sql(sql_ast) attr.cached_count_sql = sql, adapter @@ -3752,7 +3752,7 @@ def select_random(entity, limit): if max_id is None: max_id_sql = entity._cached_max_id_sql_ if max_id_sql is None: - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', [ 'COLUMN', None, pk.column ] ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'MAX', None, [ 'COLUMN', None, pk.column ] ] ], [ 'FROM', [ None, 'TABLE', entity._table_ ] ] ] 
max_id_sql, adapter = database._ast2sql(sql_ast) entity._cached_max_id_sql_ = max_id_sql diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 89814c95c..f46e385d2 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -55,6 +55,8 @@ class SQLiteTranslator(sqltranslation.SQLTranslator): class SQLiteBuilder(SQLBuilder): dialect = 'SQLite' + least_func_name = 'min' + greatest_func_name = 'max' def __init__(builder, provider, ast): builder.json1_available = provider.json1_available SQLBuilder.__init__(builder, provider, ast) @@ -118,16 +120,6 @@ def DATETIME_SUB(builder, expr, delta): if isinstance(delta, timedelta): return builder.datetime_add('datetime', expr, -delta) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' - def MIN(builder, *args): - if len(args) == 0: assert False # pragma: no cover - elif len(args) == 1: fname = 'MIN' - else: fname = 'min' - return fname, '(', join(', ', imap(builder, args)), ')' - def MAX(builder, *args): - if len(args) == 0: assert False # pragma: no cover - elif len(args) == 1: fname = 'MAX' - else: fname = 'max' - return fname, '(', join(', ', imap(builder, args)), ')' def RANDOM(builder): return 'rand()' # return '(random() / 9223372036854775807.0 + 1.0) / 2.0' PY_UPPER = make_unary_func('py_upper') diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index f2b233b07..b476f58e6 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -162,6 +162,8 @@ class SQLBuilder(object): composite_param_class = CompositeParam value_class = Value indent_spaces = " " * 4 + least_func_name = 'least' + greatest_func_name = 'greatest' def __init__(builder, provider, ast): builder.provider = provider builder.quote_name = provider.quote_name @@ -441,23 +443,23 @@ def NOT_IN(builder, expr1, x): return builder(expr1), ' NOT IN ', builder(x) expr_list = [ builder(expr) for expr in x ] return builder(expr1), ' NOT IN (', join(', ', expr_list), ')' - 
def COUNT(builder, kind, *expr_list): - if kind == 'ALL': + def COUNT(builder, distinct, *expr_list): + if not distinct: if not expr_list: return ['COUNT(*)'] return 'COUNT(', join(', ', imap(builder, expr_list)), ')' - elif kind == 'DISTINCT': - if not expr_list: throw(AstError, 'COUNT(DISTINCT) without argument') - if len(expr_list) == 1: return 'COUNT(DISTINCT ', builder(expr_list[0]), ')' - if builder.dialect == 'PostgreSQL': - return 'COUNT(DISTINCT ', builder.ROW(*expr_list), ')' - elif builder.dialect == 'MySQL': - return 'COUNT(DISTINCT ', join(', ', imap(builder, expr_list)), ')' - # Oracle and SQLite queries translated to completely different subquery syntax - else: throw(NotImplementedError) # This line must not be executed - throw(AstError, 'Invalid COUNT kind (must be ALL or DISTINCT)') - def SUM(builder, expr, distinct=False): + if not expr_list: throw(AstError, 'COUNT(DISTINCT) without argument') + if len(expr_list) == 1: + return 'COUNT(DISTINCT ', builder(expr_list[0]), ')' + + if builder.dialect == 'PostgreSQL': + return 'COUNT(DISTINCT ', builder.ROW(*expr_list), ')' + elif builder.dialect == 'MySQL': + return 'COUNT(DISTINCT ', join(', ', imap(builder, expr_list)), ')' + # Oracle and SQLite queries translated to completely different subquery syntax + else: throw(NotImplementedError) # This line must not be executed + def SUM(builder, distinct, expr): return distinct and 'coalesce(SUM(DISTINCT ' or 'coalesce(SUM(', builder(expr), '), 0)' - def AVG(builder, expr, distinct=False): + def AVG(builder, distinct, expr): return distinct and 'AVG(DISTINCT ' or 'AVG(', builder(expr), ')' UPPER = make_unary_func('upper') LOWER = make_unary_func('lower') @@ -466,15 +468,17 @@ def AVG(builder, expr, distinct=False): def COALESCE(builder, *args): if len(args) < 2: assert False # pragma: no cover return 'coalesce(', join(', ', imap(builder, args)), ')' - def MIN(builder, *args): + def MIN(builder, distinct, *args): + assert not distinct, distinct if len(args) 
== 0: assert False # pragma: no cover elif len(args) == 1: fname = 'MIN' - else: fname = 'least' + else: fname = builder.least_func_name return fname, '(', join(', ', imap(builder, args)), ')' - def MAX(builder, *args): + def MAX(builder, distinct, *args): + assert not distinct, distinct if len(args) == 0: assert False # pragma: no cover elif len(args) == 1: fname = 'MAX' - else: fname = 'greatest' + else: fname = builder.greatest_func_name return fname, '(', join(', ', imap(builder, args)), ')' def SUBSTR(builder, expr, start, len=None): if len is None: return 'substr(', builder(expr), ', ', builder(start), ')' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 0440b9f75..722fcecdd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -426,24 +426,24 @@ def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None and len(translator.expr_columns) > 1): outer_alias = 't' if aggr_func_name == 'COUNT': - outer_aggr_ast = [ 'COUNT', 'ALL' ] + outer_aggr_ast = [ 'COUNT', None ] else: assert len(translator.expr_columns) == 1 expr_ast = translator.expr_columns[0] if expr_ast[0] == 'COLUMN': outer_alias, column_name = expr_ast[1:] - outer_aggr_ast = [ aggr_func_name, [ 'COLUMN', outer_alias, column_name ] ] + outer_aggr_ast = [ aggr_func_name, None, [ 'COLUMN', outer_alias, column_name ] ] else: select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + [ [ 'AS', expr_ast, 'expr' ] ] - outer_aggr_ast = [ aggr_func_name, [ 'COLUMN', 't', 'expr' ] ] + outer_aggr_ast = [ aggr_func_name, None, [ 'COLUMN', 't', 'expr' ] ] def ast_transformer(ast): return [ 'SELECT', [ 'AGGREGATES', outer_aggr_ast ], [ 'FROM', [ outer_alias, 'SELECT', ast[1:] ] ] ] else: if aggr_func_name == 'COUNT': - if isinstance(expr_type, (tuple, EntityMeta)) and not distinct: aggr_ast = [ 'COUNT', 'ALL' ] - else: aggr_ast = [ 'COUNT', 'DISTINCT', translator.expr_columns[0] ] - else: aggr_ast = [ aggr_func_name, translator.expr_columns[0] ] + if 
isinstance(expr_type, (tuple, EntityMeta)) and not distinct: aggr_ast = [ 'COUNT', None ] + else: aggr_ast = [ 'COUNT', True, translator.expr_columns[0] ] + else: aggr_ast = [ aggr_func_name, None, translator.expr_columns[0] ] if aggr_ast: select_ast = [ 'AGGREGATES', aggr_ast ] elif isinstance(translator.expr_type, EntityMeta) and not translator.parent \ and not translator.aggregated and not translator.optimize: @@ -1027,10 +1027,10 @@ def count(monad): translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. Got: {EXPR}') expr = monad.getsql() - count_kind = 'DISTINCT' + distinct = True if monad.type is bool: expr = [ 'CASE', None, [ [ expr[0], [ 'VALUE', 1 ] ] ], [ 'VALUE', None ] ] - count_kind = 'ALL' + distinct = None elif len(expr) == 1: expr = expr[0] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr @@ -1047,7 +1047,7 @@ def count(monad): '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. 
Got: {EXPR}' % translator.dialect) - result = translator.ExprMonad.new(translator, int, [ 'COUNT', count_kind, expr ]) + result = translator.ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ]) result.aggregated = True return result def aggregate(monad, func_name): @@ -1075,9 +1075,8 @@ def aggregate(monad, func_name): % translator.dialect) if func_name == 'AVG': result_type = float else: result_type = expr_type - aggr_ast = [ func_name, expr ] - if getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG'): - aggr_ast.append(True) + distinct = getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG') + aggr_ast = [ func_name, distinct, expr ] result = translator.ExprMonad.new(translator, result_type, aggr_ast) result.aggregated = True return result @@ -2140,7 +2139,7 @@ def call(monad, x=None): translator = monad.translator if isinstance(x, translator.StringConstMonad) and x.value == '*': x = None if x is not None: return x.count() - result = translator.ExprMonad.new(translator, int, [ 'COUNT', 'ALL' ]) + result = translator.ExprMonad.new(translator, int, [ 'COUNT', None ]) result.aggregated = True return result @@ -2218,7 +2217,7 @@ def minmax(monad, sqlop, *args): for i, arg in enumerate(args): if arg.type is bool: args[i] = NumericExprMonad(translator, int, [ 'TO_INT', arg.getsql() ]) - sql = [ sqlop ] + [ arg.getsql()[0] for arg in args ] + sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ] return translator.ExprMonad.new(translator, t, sql) class FuncSelectMonad(FuncMonad): @@ -2367,35 +2366,35 @@ def count(monad): sql_ast = make_aggr = None extra_grouping = False if not distinct and monad.tableref.name_path != translator.optimize: - make_aggr = lambda expr_list: [ 'COUNT', 'ALL' ] + make_aggr = lambda expr_list: [ 'COUNT', None ] elif len(expr_list) == 1: - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT' ] + expr_list + make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list elif translator.dialect == 
'Oracle': if monad.tableref.name_path == translator.optimize: alias, pk_columns = monad.tableref.make_join(pk_only=True) - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT' if distinct else 'ALL', [ 'COLUMN', alias, 'ROWID' ] ] + make_aggr = lambda expr_list: [ 'COUNT', distinct, [ 'COLUMN', alias, 'ROWID' ] ] else: extra_grouping = True - if translator.hint_join: make_aggr = lambda expr_list: [ 'COUNT', 'ALL' ] - else: make_aggr = lambda expr_list: [ 'COUNT', 'ALL', [ 'COUNT', 'ALL' ] ] + if translator.hint_join: make_aggr = lambda expr_list: [ 'COUNT', None ] + else: make_aggr = lambda expr_list: [ 'COUNT', None, [ 'COUNT', None ] ] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr_list expr = [ 'CASE', None, [ [ [ 'IS_NULL', row ], [ 'VALUE', None ] ] ], row ] - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT', expr ] + make_aggr = lambda expr_list: [ 'COUNT', True, expr ] elif translator.row_value_syntax: - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT' ] + expr_list + make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list elif translator.dialect == 'SQLite': if not distinct: alias, pk_columns = monad.tableref.make_join(pk_only=True) - make_aggr = lambda expr_list: [ 'COUNT', 'ALL', [ 'COLUMN', alias, 'ROWID' ] ] + make_aggr = lambda expr_list: [ 'COUNT', None, [ 'COLUMN', alias, 'ROWID' ] ] elif translator.hint_join: # Same join as in Oracle extra_grouping = True - make_aggr = lambda expr_list: [ 'COUNT', 'ALL' ] + make_aggr = lambda expr_list: [ 'COUNT', None ] elif translator.sqlite_version < (3, 6, 21): alias, pk_columns = monad.tableref.make_join(pk_only=False) - make_aggr = lambda expr_list: [ 'COUNT', 'DISTINCT', [ 'COLUMN', alias, 'ROWID' ] ] + make_aggr = lambda expr_list: [ 'COUNT', True, [ 'COLUMN', alias, 'ROWID' ] ] else: - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ 't', 'SELECT', [ [ 'DISTINCT' ] + expr_list, from_ast, [ 'WHERE' ] 
+ outer_conditions + inner_conditions ] ] ] ] @@ -2425,10 +2424,8 @@ def aggregate(monad, func_name): % (func_name.lower(), type2str(item_type))) else: assert False # pragma: no cover - if monad.forced_distinct and func_name in ('SUM', 'AVG'): - make_aggr = lambda expr_list: [ func_name ] + expr_list + [ True ] - else: - make_aggr = lambda expr_list: [ func_name ] + expr_list + distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') + make_aggr = lambda expr_list: [ func_name, distinct ] + expr_list if translator.hint_join: sql_ast, optimized = monad._joined_subselect(make_aggr, coalesce_to_zero=(func_name=='SUM')) @@ -2604,8 +2601,8 @@ def aggregate(monad, func_name): if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] else: subquery.outer_conditions = [ outer_cond ] result_type = float if func_name == 'AVG' else monad.type.item_type - aggr_ast = [ func_name, expr ] - if monad.forced_distinct and func_name in ('SUM', 'AVG'): aggr_ast.append(True) + distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') + aggr_ast = [ func_name, distinct, expr ] if translator.optimize != monad.tableref.name_path: sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], subquery.from_ast, @@ -2733,14 +2730,14 @@ def count(monad): expr_type = sub.expr_type if isinstance(expr_type, (tuple, EntityMeta)): if not sub.distinct: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ] + select_ast = [ 'AGGREGATES', [ 'COUNT', None ] ] elif len(sub.expr_columns) == 1: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'DISTINCT' ] + sub.expr_columns ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True ] + sub.expr_columns ] elif translator.dialect == 'Oracle': - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL', [ 'COUNT', 'ALL' ] ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None, [ 'COUNT', None ] ] ], from_ast, where_ast, [ 'GROUP_BY' ] + sub.expr_columns ] elif translator.row_value_syntax: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'DISTINCT' ] + 
sub.expr_columns ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True ] + sub.expr_columns ] elif translator.dialect == 'SQLite': if translator.sqlite_version < (3, 6, 21): if sub.aggregated: throw(TranslationError) @@ -2748,16 +2745,16 @@ def count(monad): subquery_ast = sub.shallow_copy_of_subquery_ast() from_ast, where_ast = subquery_ast[2:4] sql_ast = [ 'SELECT', - [ 'AGGREGATES', [ 'COUNT', 'DISTINCT', [ 'COLUMN', alias, 'ROWID' ] ] ], + [ 'AGGREGATES', [ 'COUNT', True, [ 'COLUMN', alias, 'ROWID' ] ] ], from_ast, where_ast ] else: alias = translator.subquery.make_alias('t') - sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', 'ALL' ] ], + sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ alias, 'SELECT', [ [ 'DISTINCT' ] + sub.expr_columns, from_ast, where_ast ] ] ] ] else: assert False # pragma: no cover elif len(sub.expr_columns) == 1: - select_ast = [ 'AGGREGATES', [ 'COUNT', 'DISTINCT', sub.expr_columns[0] ] ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True, sub.expr_columns[0] ] ] else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] @@ -2780,8 +2777,8 @@ def aggregate(monad, func_name): % (func_name.lower(), type2str(expr_type))) else: assert False # pragma: no cover assert len(sub.expr_columns) == 1 - aggr_ast = [ func_name, sub.expr_columns[0] ] - if monad.forced_distinct and func_name in ('SUM', 'AVG'): aggr_ast.append(True) + distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') + aggr_ast = [ func_name, distinct, sub.expr_columns[0] ] select_ast = [ 'AGGREGATES', aggr_ast ] sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] result_type = float if func_name == 'AVG' else expr_type diff --git a/pony/orm/tests/test_declarative_sqltranslator2.py b/pony/orm/tests/test_declarative_sqltranslator2.py index 87ea5000a..ba684ee7e 100644 --- a/pony/orm/tests/test_declarative_sqltranslator2.py +++ b/pony/orm/tests/test_declarative_sqltranslator2.py @@ -100,12 
+100,12 @@ def test_distinct1(self): self.assertEqual(q.count(), 7) def test_distinct3(self): q = select(d for d in Department if len(s for c in d.courses for s in c.students) > len(s for s in Student)) - self.assertEqual("DISTINCT" in flatten(q._translator.conditions), True) self.assertEqual(q[:], []) + self.assertTrue('DISTINCT' in db.last_sql) def test_distinct4(self): q = select(d for d in Department if len(d.groups.students) > 3) - self.assertEqual("DISTINCT" not in flatten(q._translator.conditions), True) self.assertEqual(q[:], [Department[2]]) + self.assertTrue("DISTINCT" not in db.last_sql) def test_distinct5(self): result = set(select(s for s in Student)) self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5], Student[6], Student[7]}) From 3ac867799434956d4dcb7991b6ebcaadaadfff6a Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 4 Jul 2018 19:14:07 +0300 Subject: [PATCH 288/547] distinct option added to query.sum(), query.avg() & query.count() --- pony/orm/core.py | 26 +++++----- pony/orm/sqltranslation.py | 23 +++++---- .../tests/test_declarative_query_set_monad.py | 49 ++++++++++++++++++- 3 files changed, 74 insertions(+), 24 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e248837d8..8432e4d07 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5327,20 +5327,21 @@ def _clone(query, **kwargs): return new_query def __reduce__(query): return unpickle_query, (query._fetch(),) - def _construct_sql_and_arguments(query, range=None, aggr_func_name=None): + def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_func_distinct=None): translator = query._translator expr_type = translator.expr_type if isinstance(expr_type, EntityMeta) and query._attrs_to_prefetch_dict: attrs_to_prefetch = tuple(sorted(query._attrs_to_prefetch_dict.get(expr_type, ()))) else: attrs_to_prefetch = () - sql_key = (query._key, range, query._distinct, aggr_func_name, query._for_update, query._nowait, - 
options.INNER_JOIN_SYNTAX, attrs_to_prefetch) + sql_key = (query._key, range, query._distinct, aggr_func_name, aggr_func_distinct, + query._for_update, query._nowait, options.INNER_JOIN_SYNTAX, attrs_to_prefetch) database = query._database cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( - range, query._distinct, aggr_func_name, query._for_update, query._nowait, attrs_to_prefetch) + range, query._distinct, aggr_func_name, aggr_func_distinct, + query._for_update, query._nowait, attrs_to_prefetch) cache = database._get_cache() sql, adapter = database.provider.ast2sql(sql_ast) cache_entry = sql, adapter, attr_offsets @@ -5730,9 +5731,10 @@ def page(query, pagenum, pagesize=10): start = (pagenum - 1) * pagesize stop = pagenum * pagesize return query[start:stop] - def _aggregate(query, aggr_func_name): + def _aggregate(query, aggr_func_name, distinct=None): translator = query._translator - sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(aggr_func_name=aggr_func_name) + sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments( + aggr_func_name=aggr_func_name, aggr_func_distinct=distinct) cache = query._database._get_cache() try: result = cache.query_results[query_key] except KeyError: @@ -5751,11 +5753,11 @@ def _aggregate(query, aggr_func_name): if query_key is not None: cache.query_results[query_key] = result return result @cut_traceback - def sum(query): - return query._aggregate('SUM') + def sum(query, distinct=None): + return query._aggregate('SUM', distinct) @cut_traceback - def avg(query): - return query._aggregate('AVG') + def avg(query, distinct=None): + return query._aggregate('AVG', distinct) @cut_traceback def min(query): return query._aggregate('MIN') @@ -5763,8 +5765,8 @@ def min(query): def max(query): return query._aggregate('MAX') @cut_traceback - def count(query): - return query._aggregate('COUNT') + def count(query, 
distinct=None): + return query._aggregate('COUNT', distinct) @cut_traceback def for_update(query, nowait=False): provider = query._database.provider diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 722fcecdd..684b883ca 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -392,8 +392,8 @@ def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False return next(iter(translator.aggregated_subquery_paths)) - def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None, for_update=False, nowait=False, - attrs_to_prefetch=(), is_not_null_checks=False): + def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, + for_update=False, nowait=False, attrs_to_prefetch=(), is_not_null_checks=False): attr_offsets = None if distinct is None: distinct = translator.distinct ast_transformer = lambda ast: ast @@ -422,28 +422,31 @@ def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None assert len(translator.expr_columns) == 1 aggr_ast = None if groupby_monads or (aggr_func_name == 'COUNT' and distinct - and isinstance(translator.expr_type, EntityMeta) - and len(translator.expr_columns) > 1): + and isinstance(translator.expr_type, EntityMeta) + and len(translator.expr_columns) > 1): outer_alias = 't' - if aggr_func_name == 'COUNT': + if aggr_func_name == 'COUNT' and not aggr_func_distinct: outer_aggr_ast = [ 'COUNT', None ] else: assert len(translator.expr_columns) == 1 expr_ast = translator.expr_columns[0] if expr_ast[0] == 'COLUMN': outer_alias, column_name = expr_ast[1:] - outer_aggr_ast = [ aggr_func_name, None, [ 'COLUMN', outer_alias, column_name ] ] + outer_aggr_ast = [ aggr_func_name, aggr_func_distinct, [ 'COLUMN', outer_alias, column_name ] ] else: select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + [ [ 'AS', expr_ast, 'expr' ] ] - outer_aggr_ast = [ 
aggr_func_name, None, [ 'COLUMN', 't', 'expr' ] ] + outer_aggr_ast = [ aggr_func_name, aggr_func_distinct, [ 'COLUMN', 't', 'expr' ] ] def ast_transformer(ast): return [ 'SELECT', [ 'AGGREGATES', outer_aggr_ast ], [ 'FROM', [ outer_alias, 'SELECT', ast[1:] ] ] ] else: if aggr_func_name == 'COUNT': - if isinstance(expr_type, (tuple, EntityMeta)) and not distinct: aggr_ast = [ 'COUNT', None ] - else: aggr_ast = [ 'COUNT', True, translator.expr_columns[0] ] - else: aggr_ast = [ aggr_func_name, None, translator.expr_columns[0] ] + if isinstance(expr_type, (tuple, EntityMeta)) and not distinct and not aggr_func_distinct: + aggr_ast = [ 'COUNT', aggr_func_distinct ] + else: + aggr_ast = [ 'COUNT', True if aggr_func_distinct is None else aggr_func_distinct, + translator.expr_columns[0] ] + else: aggr_ast = [ aggr_func_name, aggr_func_distinct, translator.expr_columns[0] ] if aggr_ast: select_ast = [ 'AGGREGATES', aggr_ast ] elif isinstance(translator.expr_type, EntityMeta) and not translator.parent \ and not translator.aggregated and not translator.optimize: diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 183a64446..195b0c16c 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -73,20 +73,49 @@ def test_count_4(self): result = set(select(c for c in Course if count(s for s in c.students) > 1)) self.assertEqual(result, {Course['C1', 1], Course['C2', 1]}) + def test_count_5(self): + result = select(c.semester for c in Course).count(distinct=True) + self.assertEqual(result, 2) + + def test_count_6(self): + result = select(c for c in Course).count() + self.assertEqual(result, 3) + self.assertTrue('DISTINCT' not in db.last_sql) + + def test_count_7(self): + result = select(c for c in Course).count(distinct=True) + self.assertEqual(result, 3) + self.assertTrue('DISTINCT' in db.last_sql) + @raises_exception(TypeError) def 
test_sum_1(self): result = set(select(g for g in Group if sum(s for s in Student if s.group == g) > 1)) - self.assertEqual(result, set()) @raises_exception(TypeError) def test_sum_2(self): select(g for g in Group if sum(s.name for s in Student if s.group == g) > 1) def test_sum_3(self): + result = sum(s.scholarship for s in Student) + self.assertEqual(result, 600) + + def test_sum_4(self): + result = sum(s.scholarship for s in Student if s.name == 'Unnamed') + self.assertEqual(result, 0) + + def test_sum_5(self): + result = select(c.semester for c in Course).sum() + self.assertEqual(result, 4) + + def test_sum_6(self): + result = select(c.semester for c in Course).sum(distinct=True) + self.assertEqual(result, 3) + + def test_sum_7(self): result = set(select(g for g in Group if sum(s.scholarship for s in Student if s.group == g) > 500)) self.assertEqual(result, set()) - def test_sum_4(self): + def test_sum_8(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum() > 200)) self.assertEqual(result, {Group[2]}) @@ -102,6 +131,10 @@ def test_min_3(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).min() == 0)) self.assertEqual(result, {Group[1]}) + def test_min_4(self): + result = select(s.scholarship for s in Student).min() + self.assertEqual(0, result) + def test_max_1(self): result = set(select(g for g in Group if max(s.scholarship for s in Student if s.group == g) > 100)) self.assertEqual(result, {Group[2]}) @@ -114,6 +147,10 @@ def test_max_3(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).max() == 100)) self.assertEqual(result, {Group[1]}) + def test_max_4(self): + result = select(s.scholarship for s in Student).max() + self.assertEqual(result, 500) + def test_avg_1(self): result = select(g for g in Group if avg(s.scholarship for s in Student if s.group == g) == 50)[:] self.assertEqual(result, [Group[1]]) @@ -122,6 +159,14 @@ def test_avg_2(self): 
result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg() == 50)) self.assertEqual(result, {Group[1]}) + def test_avg_3(self): + result = select(c.semester for c in Course).avg() + self.assertAlmostEqual(1.33, result, places=2) + + def test_avg_4(self): + result = select(c.semester for c in Course).avg(distinct=True) + self.assertAlmostEqual(1.5, result) + def test_exists(self): result = set(select(g for g in Group if exists(s for s in g.students if s.name == 'S1'))) self.assertEqual(result, {Group[1]}) From 2ad6eeb661832f26451ccbf48086ed5b12c90a56 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 10:09:54 +0300 Subject: [PATCH 289/547] distinct option added to QuerySetMonad methods sum(), avg() & count() --- pony/orm/sqltranslation.py | 38 +++++++++++-------- .../tests/test_declarative_query_set_monad.py | 21 ++++++++++ 2 files changed, 43 insertions(+), 16 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 684b883ca..b40f10905 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2722,9 +2722,14 @@ def negate(monad): assert sql[0] == 'EXISTS' translator = monad.translator return translator.BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:]) - def count(monad): + def count(monad, distinct=None): translator = monad.translator sub = monad.subtranslator + if distinct is not None: + if isinstance(distinct, NumericConstMonad) and isinstance(distinct.value, bool): + distinct = distinct.value + else: + throw(TypeError, '`distinct` value should be True or False, got: %s' % ast2src(distinct.node)) if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') subquery_ast = sub.shallow_copy_of_subquery_ast() from_ast, where_ast = subquery_ast[2:4] @@ -2732,15 +2737,15 @@ def count(monad): expr_type = sub.expr_type if isinstance(expr_type, (tuple, EntityMeta)): - if not sub.distinct: + if not sub.distinct and not distinct: select_ast = [ 
'AGGREGATES', [ 'COUNT', None ] ] elif len(sub.expr_columns) == 1: - select_ast = [ 'AGGREGATES', [ 'COUNT', True ] + sub.expr_columns ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct ] + sub.expr_columns ] elif translator.dialect == 'Oracle': sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None, [ 'COUNT', None ] ] ], from_ast, where_ast, [ 'GROUP_BY' ] + sub.expr_columns ] elif translator.row_value_syntax: - select_ast = [ 'AGGREGATES', [ 'COUNT', True ] + sub.expr_columns ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct ] + sub.expr_columns ] elif translator.dialect == 'SQLite': if translator.sqlite_version < (3, 6, 21): if sub.aggregated: throw(TranslationError) @@ -2748,22 +2753,22 @@ def count(monad): subquery_ast = sub.shallow_copy_of_subquery_ast() from_ast, where_ast = subquery_ast[2:4] sql_ast = [ 'SELECT', - [ 'AGGREGATES', [ 'COUNT', True, [ 'COLUMN', alias, 'ROWID' ] ] ], + [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, [ 'COLUMN', alias, 'ROWID' ] ] ], from_ast, where_ast ] else: alias = translator.subquery.make_alias('t') sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], - [ 'FROM', [ alias, 'SELECT', [ - [ 'DISTINCT' ] + sub.expr_columns, from_ast, where_ast ] ] ] ] + [ 'FROM', [ alias, 'SELECT', [ [ 'DISTINCT' if distinct is not False else 'ALL' ] + + sub.expr_columns, from_ast, where_ast ] ] ] ] else: assert False # pragma: no cover elif len(sub.expr_columns) == 1: - select_ast = [ 'AGGREGATES', [ 'COUNT', True, sub.expr_columns[0] ] ] + select_ast = [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, sub.expr_columns[0] ] ] else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] return translator.ExprMonad.new(translator, int, sql_ast) len = count - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None): translator = monad.translator sub 
= monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') @@ -2780,22 +2785,23 @@ def aggregate(monad, func_name): % (func_name.lower(), type2str(expr_type))) else: assert False # pragma: no cover assert len(sub.expr_columns) == 1 - distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') + if distinct is None: + distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') aggr_ast = [ func_name, distinct, sub.expr_columns[0] ] select_ast = [ 'AGGREGATES', aggr_ast ] sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] result_type = float if func_name == 'AVG' else expr_type return translator.ExprMonad.new(translator, result_type, sql_ast) - def call_count(monad): - return monad.count() - def call_sum(monad): - return monad.aggregate('SUM') + def call_count(monad, distinct=None): + return monad.count(distinct=distinct) + def call_sum(monad, distinct=None): + return monad.aggregate('SUM', distinct) def call_min(monad): return monad.aggregate('MIN') def call_max(monad): return monad.aggregate('MAX') - def call_avg(monad): - return monad.aggregate('AVG') + def call_avg(monad, distinct=None): + return monad.aggregate('AVG', distinct) def find_or_create_having_ast(subquery_ast): groupby_offset = None diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 195b0c16c..545136014 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -69,6 +69,16 @@ def test_count_3(self): result = set(select(s for s in Student if count(c for c in s.courses) > 1)) self.assertEqual(result, {Student[2], Student[3]}) + def test_count_3a(self): + result = set(select(s for s in Student if select(c for c in s.courses).count() > 1)) + self.assertEqual(result, {Student[2], Student[3]}) + self.assertTrue('DISTINCT' in db.last_sql) + + def test_count_3b(self): + result = set(select(s for s in Student if 
select(c for c in s.courses).count(distinct=False) > 1)) + self.assertEqual(result, {Student[2], Student[3]}) + self.assertTrue('DISTINCT' not in db.last_sql) + def test_count_4(self): result = set(select(c for c in Course if count(s for s in c.students) > 1)) self.assertEqual(result, {Course['C1', 1], Course['C2', 1]}) @@ -118,6 +128,12 @@ def test_sum_7(self): def test_sum_8(self): result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum() > 200)) self.assertEqual(result, {Group[2]}) + self.assertTrue('DISTINCT' not in db.last_sql) + + def test_sum_9(self): + result = set(select(g for g in Group if select(s.scholarship for s in g.students).sum(distinct=True) > 200)) + self.assertEqual(result, {Group[2]}) + self.assertTrue('DISTINCT' in db.last_sql) def test_min_1(self): result = set(select(g for g in Group if min(s.name for s in Student if s.group == g) == 'S1')) @@ -167,6 +183,11 @@ def test_avg_4(self): result = select(c.semester for c in Course).avg(distinct=True) self.assertAlmostEqual(1.5, result) + def test_avg_5(self): + result = set(select(g for g in Group if select(s.scholarship for s in g.students).avg(distinct=True) == 50)) + self.assertEqual(result, {Group[1]}) + self.assertTrue('AVG(DISTINCT' in db.last_sql) + def test_exists(self): result = set(select(g for g in Group if exists(s for s in g.students if s.name == 'S1'))) self.assertEqual(result, {Group[1]}) From 2d0a1e25433316b5b2192871e4c91a4e0ba54c09 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 11:41:54 +0300 Subject: [PATCH 290/547] Add `distinct` option to Monad.count() and Monad.aggregate() methods --- pony/orm/sqlbuilding.py | 3 ++ pony/orm/sqltranslation.py | 54 ++++++++++--------- .../tests/test_declarative_query_set_monad.py | 20 +++++++ 3 files changed, 53 insertions(+), 24 deletions(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index b476f58e6..966bdf547 100644 --- a/pony/orm/sqlbuilding.py +++ 
b/pony/orm/sqlbuilding.py @@ -444,6 +444,7 @@ def NOT_IN(builder, expr1, x): expr_list = [ builder(expr) for expr in x ] return builder(expr1), ' NOT IN (', join(', ', expr_list), ')' def COUNT(builder, distinct, *expr_list): + assert distinct in (None, True, False) if not distinct: if not expr_list: return ['COUNT(*)'] return 'COUNT(', join(', ', imap(builder, expr_list)), ')' @@ -458,8 +459,10 @@ def COUNT(builder, distinct, *expr_list): # Oracle and SQLite queries translated to completely different subquery syntax else: throw(NotImplementedError) # This line must not be executed def SUM(builder, distinct, expr): + assert distinct in (None, True, False) return distinct and 'coalesce(SUM(DISTINCT ' or 'coalesce(SUM(', builder(expr), '), 0)' def AVG(builder, distinct, expr): + assert distinct in (None, True, False) return distinct and 'AVG(DISTINCT ' or 'AVG(', builder(expr), ')' UPPER = make_unary_func('upper') LOWER = make_unary_func('lower') diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index b40f10905..d7b028291 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1026,11 +1026,12 @@ def getattr(monad, attrname): return translator.MethodMonad(monad, attrname) return property_method() def len(monad): throw(TypeError) - def count(monad): + def count(monad, distinct=None): + distinct = distinct_from_monad(distinct, default=True) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. 
Got: {EXPR}') expr = monad.getsql() - distinct = True + if monad.type is bool: expr = [ 'CASE', None, [ [ expr[0], [ 'VALUE', 1 ] ] ], [ 'VALUE', None ] ] distinct = None @@ -1053,7 +1054,8 @@ def count(monad): result = translator.ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ]) result.aggregated = True return result - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None): + distinct = distinct_from_monad(distinct) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. Got: {EXPR}') expr_type = monad.type @@ -1078,7 +1080,8 @@ def aggregate(monad, func_name): % translator.dialect) if func_name == 'AVG': result_type = float else: result_type = expr_type - distinct = getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG') + if distinct is None: + distinct = getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG') aggr_ast = [ func_name, distinct, expr ] result = translator.ExprMonad.new(translator, result_type, aggr_ast) result.aggregated = True @@ -1102,6 +1105,13 @@ def to_int(monad): def to_real(monad): return NumericExprMonad(monad.translator, float, [ 'TO_REAL', monad.getsql()[0] ]) +def distinct_from_monad(distinct, default=None): + if distinct is None: + return default + if isinstance(distinct, NumericConstMonad) and isinstance(distinct.value, bool): + return distinct.value + throw(TypeError, '`distinct` value should be True or False. 
Got: %s' % ast2src(distinct.node)) + class RawSQLMonad(Monad): def __init__(monad, translator, rawtype, varkey): if rawtype.result_type is None: type = rawtype @@ -1188,7 +1198,7 @@ def __call__(monad, *args, **kwargs): def contains(monad, item, not_in=False): raise_forgot_parentheses(monad) def nonzero(monad): raise_forgot_parentheses(monad) def negate(monad): raise_forgot_parentheses(monad) - def aggregate(monad, func_name): raise_forgot_parentheses(monad) + def aggregate(monad, func_name, distinct=None): raise_forgot_parentheses(monad) def __getitem__(monad, key): raise_forgot_parentheses(monad) def __add__(monad, monad2): raise_forgot_parentheses(monad) @@ -2138,10 +2148,10 @@ def call(monad, obj_monad, name_monad): class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count - def call(monad, x=None): + def call(monad, x=None, distinct=None): translator = monad.translator if isinstance(x, translator.StringConstMonad) and x.value == '*': x = None - if x is not None: return x.count() + if x is not None: return x.count(distinct) result = translator.ExprMonad.new(translator, int, [ 'COUNT', None ]) result.aggregated = True return result @@ -2153,13 +2163,13 @@ def call(monad, x): class FuncSumMonad(FuncMonad): func = sum, core.sum - def call(monad, x): - return x.aggregate('SUM') + def call(monad, x, distinct=None): + return x.aggregate('SUM', distinct) class FuncAvgMonad(FuncMonad): func = utils.avg, core.avg - def call(monad, x): - return x.aggregate('AVG') + def call(monad, x, distinct=None): + return x.aggregate('AVG', distinct) class FuncCoalesceMonad(FuncMonad): func = coalesce @@ -2356,8 +2366,9 @@ def requires_distinct(monad, joined=False, for_count=False): if not for_count and not translator.hint_join: return True if isinstance(monad.parent, monad.translator.AttrSetMonad): return True return False - def count(monad): + def count(monad, distinct=None): translator = monad.translator + distinct = distinct_from_monad(distinct, 
monad.requires_distinct(joined=translator.hint_join, for_count=True)) subquery = monad._subselect() expr_list = subquery.expr_list @@ -2365,7 +2376,6 @@ def count(monad): inner_conditions = subquery.conditions outer_conditions = subquery.outer_conditions - distinct = monad.requires_distinct(joined=translator.hint_join, for_count=True) sql_ast = make_aggr = None extra_grouping = False if not distinct and monad.tableref.name_path != translator.optimize: @@ -2413,7 +2423,8 @@ def count(monad): else: result.nogroup = True return result len = count - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None): + distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator item_type = monad.type.item_type @@ -2427,7 +2438,6 @@ def aggregate(monad, func_name): % (func_name.lower(), type2str(item_type))) else: assert False # pragma: no cover - distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') make_aggr = lambda expr_list: [ func_name, distinct ] + expr_list if translator.hint_join: @@ -2595,7 +2605,8 @@ def __init__(monad, op, sqlop, left, right): monad.sqlop = sqlop monad.left = left monad.right = right - def aggregate(monad, func_name): + def aggregate(monad, func_name, distinct=None): + distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator subquery = Subquery(translator.subquery) expr = monad.getsql(subquery)[0] @@ -2604,7 +2615,6 @@ def aggregate(monad, func_name): if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] else: subquery.outer_conditions = [ outer_cond ] result_type = float if func_name == 'AVG' else monad.type.item_type - distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') aggr_ast = [ func_name, distinct, expr ] if translator.optimize != monad.tableref.name_path: sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], @@ -2723,13 +2733,10 @@ def 
negate(monad): translator = monad.translator return translator.BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:]) def count(monad, distinct=None): + distinct = distinct_from_monad(distinct) translator = monad.translator sub = monad.subtranslator - if distinct is not None: - if isinstance(distinct, NumericConstMonad) and isinstance(distinct.value, bool): - distinct = distinct.value - else: - throw(TypeError, '`distinct` value should be True or False, got: %s' % ast2src(distinct.node)) + if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') subquery_ast = sub.shallow_copy_of_subquery_ast() from_ast, where_ast = subquery_ast[2:4] @@ -2769,6 +2776,7 @@ def count(monad, distinct=None): return translator.ExprMonad.new(translator, int, sql_ast) len = count def aggregate(monad, func_name, distinct=None): + distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') @@ -2785,8 +2793,6 @@ def aggregate(monad, func_name, distinct=None): % (func_name.lower(), type2str(expr_type))) else: assert False # pragma: no cover assert len(sub.expr_columns) == 1 - if distinct is None: - distinct = monad.forced_distinct and func_name in ('SUM', 'AVG') aggr_ast = [ func_name, distinct, sub.expr_columns[0] ] select_ast = [ 'AGGREGATES', aggr_ast ] sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 545136014..1197234b4 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -97,6 +97,18 @@ def test_count_7(self): self.assertEqual(result, 3) self.assertTrue('DISTINCT' in db.last_sql) + def test_count_8(self): + select(count(c.semester, distinct=False) for c in Course)[:] + 
self.assertTrue('DISTINCT' not in db.last_sql) + + @raises_exception(TypeError, "`distinct` value should be True or False. Got: s.name.startswith('P')") + def test_count_9(self): + select(count(s, distinct=s.name.startswith('P')) for s in Student) + + def test_count_10(self): + select(count('*', distinct=True) for s in Student)[:] + self.assertTrue('DISTINCT' not in db.last_sql) + @raises_exception(TypeError) def test_sum_1(self): result = set(select(g for g in Group if sum(s for s in Student if s.group == g) > 1)) @@ -135,6 +147,10 @@ def test_sum_9(self): self.assertEqual(result, {Group[2]}) self.assertTrue('DISTINCT' in db.last_sql) + def test_sum_10(self): + select(sum(s.scholarship, distinct=True) for s in Student)[:] + self.assertTrue('SUM(DISTINCT' in db.last_sql) + def test_min_1(self): result = set(select(g for g in Group if min(s.name for s in Student if s.group == g) == 'S1')) self.assertEqual(result, {Group[1]}) @@ -188,6 +204,10 @@ def test_avg_5(self): self.assertEqual(result, {Group[1]}) self.assertTrue('AVG(DISTINCT' in db.last_sql) + def test_avg_6(self): + select(avg(s.scholarship, distinct=True) for s in Student)[:] + self.assertTrue('AVG(DISTINCT' in db.last_sql) + def test_exists(self): result = set(select(g for g in Group if exists(s for s in g.students if s.name == 'S1'))) self.assertEqual(result, {Group[1]}) From c3136f2f1f9925509f6d23f749890dca1fbd8621 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 14:50:24 +0300 Subject: [PATCH 291/547] SQLBuilder.TO_STR() --- pony/orm/dbproviders/mysql.py | 2 ++ pony/orm/dbproviders/oracle.py | 2 ++ pony/orm/dbproviders/postgres.py | 2 ++ pony/orm/sqlbuilding.py | 2 ++ 4 files changed, 8 insertions(+) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 11c2410e1..14029a343 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -64,6 +64,8 @@ def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS SIGNED)' def 
TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS DOUBLE)' + def TO_STR(builder, expr): + return 'CAST(', builder(expr), ' AS CHAR)' def YEAR(builder, expr): return 'year(', builder(expr), ')' def MONTH(builder, expr): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 0c6dcf8f3..07af1818d 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -212,6 +212,8 @@ def LIMIT(builder, limit, offset=None): assert False # pragma: no cover def TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS NUMBER)' + def TO_STR(builder, expr): + return 'TO_CHAR(', builder(expr), ')' def DATE(builder, expr): return 'TRUNC(', builder(expr), ')' def RANDOM(builder): diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index deea7c9bb..d77967c19 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -61,6 +61,8 @@ def INSERT(builder, table_name, columns, values, returning=None): return result def TO_INT(builder, expr): return '(', builder(expr), ')::int' + def TO_STR(builder, expr): + return '(', builder(expr), ')::text' def TO_REAL(builder, expr): return '(', builder(expr), ')::double precision' def DATE(builder, expr): diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 966bdf547..313c4508d 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -513,6 +513,8 @@ def REPLACE(builder, str, from_, to): return 'replace(', builder(str), ', ', builder(from_), ', ', builder(to), ')' def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS integer)' + def TO_STR(builder, expr): + return 'CAST(', builder(expr), ' AS text)' def TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS real)' def TODAY(builder): From faf015439370ecc3d44c39d4ef027761d23ac19b Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 12:36:15 +0300 Subject: [PATCH 292/547] group_concat() method for Query --- pony/orm/core.py | 
26 +++++++++++++++++++------- pony/orm/sqlbuilding.py | 6 ++++++ pony/orm/sqltranslation.py | 14 +++++++++++--- pony/utils/utils.py | 5 +++++ 4 files changed, 41 insertions(+), 10 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 8432e4d07..759341f08 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -57,7 +57,7 @@ 'select', 'left_join', 'get', 'exists', 'delete', - 'count', 'sum', 'min', 'max', 'avg', 'distinct', + 'count', 'sum', 'min', 'max', 'avg', 'group_concat', 'distinct', 'JOIN', 'desc', 'between', 'concat', 'coalesce', 'raw_sql', @@ -5219,6 +5219,7 @@ def aggrfunc(*args, **kwargs): min = make_aggrfunc(builtins.min) max = make_aggrfunc(builtins.max) avg = make_aggrfunc(utils.avg) +group_concat = make_aggrfunc(utils.group_concat) distinct = make_aggrfunc(utils.distinct) @@ -5327,20 +5328,20 @@ def _clone(query, **kwargs): return new_query def __reduce__(query): return unpickle_query, (query._fetch(),) - def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_func_distinct=None): + def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type if isinstance(expr_type, EntityMeta) and query._attrs_to_prefetch_dict: attrs_to_prefetch = tuple(sorted(query._attrs_to_prefetch_dict.get(expr_type, ()))) else: attrs_to_prefetch = () - sql_key = (query._key, range, query._distinct, aggr_func_name, aggr_func_distinct, + sql_key = (query._key, range, query._distinct, (aggr_func_name, aggr_func_distinct, sep), query._for_update, query._nowait, options.INNER_JOIN_SYNTAX, attrs_to_prefetch) database = query._database cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( - range, query._distinct, aggr_func_name, aggr_func_distinct, + range, query._distinct, aggr_func_name, aggr_func_distinct, sep, query._for_update, query._nowait, 
attrs_to_prefetch) cache = database._get_cache() sql, adapter = database.provider.ast2sql(sql_ast) @@ -5731,10 +5732,10 @@ def page(query, pagenum, pagesize=10): start = (pagenum - 1) * pagesize stop = pagenum * pagesize return query[start:stop] - def _aggregate(query, aggr_func_name, distinct=None): + def _aggregate(query, aggr_func_name, distinct=None, sep=None): translator = query._translator sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments( - aggr_func_name=aggr_func_name, aggr_func_distinct=distinct) + aggr_func_name=aggr_func_name, aggr_func_distinct=distinct, sep=sep) cache = query._database._get_cache() try: result = cache.query_results[query_key] except KeyError: @@ -5746,7 +5747,12 @@ def _aggregate(query, aggr_func_name, distinct=None): if result is None: pass elif aggr_func_name == 'COUNT': pass else: - expr_type = float if aggr_func_name == 'AVG' else translator.expr_type + if aggr_func_name == 'AVG': + expr_type = float + elif aggr_func_name == 'GROUP_CONCAT': + expr_type = basestring + else: + expr_type = translator.expr_type provider = query._database.provider converter = provider.get_converter_by_py_type(expr_type) result = converter.sql2py(result) @@ -5759,6 +5765,12 @@ def sum(query, distinct=None): def avg(query, distinct=None): return query._aggregate('AVG', distinct) @cut_traceback + def group_concat(query, sep=None, distinct=None): + if sep is not None: + if not isinstance(sep, basestring): + throw(TypeError, '`sep` option for `group_concat` should be of type str. 
Got: %s' % type(sep).__name__) + return query._aggregate('GROUP_CONCAT', distinct, sep) + @cut_traceback def min(query): return query._aggregate('MIN') @cut_traceback diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 313c4508d..bde07d2c5 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -464,6 +464,12 @@ def SUM(builder, distinct, expr): def AVG(builder, distinct, expr): assert distinct in (None, True, False) return distinct and 'AVG(DISTINCT ' or 'AVG(', builder(expr), ')' + def GROUP_CONCAT(builder, distinct, expr, sep=None): + assert distinct in (None, True, False) + result = distinct and 'GROUP_CONCAT(DISTINCT ' or 'GROUP_CONCAT(', builder(expr) + if sep is not None: + result = result, ', ', builder(sep) + return result, ')' UPPER = make_unary_func('upper') LOWER = make_unary_func('lower') LENGTH = make_unary_func('length') diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index d7b028291..632e0eaea 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -392,7 +392,7 @@ def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False return next(iter(translator.aggregated_subquery_paths)) - def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, + def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, for_update=False, nowait=False, attrs_to_prefetch=(), is_not_null_checks=False): attr_offsets = None if distinct is None: distinct = translator.distinct @@ -432,10 +432,14 @@ def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None expr_ast = translator.expr_columns[0] if expr_ast[0] == 'COLUMN': outer_alias, column_name = expr_ast[1:] - outer_aggr_ast = [ aggr_func_name, aggr_func_distinct, [ 'COLUMN', outer_alias, column_name ] ] + outer_aggr_ast = [aggr_func_name, 
aggr_func_distinct, ['COLUMN', outer_alias, column_name]] + if aggr_func_name == 'GROUP_CONCAT' and sep is not None: + outer_aggr_ast.append(['VALUE', sep]) else: select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + [ [ 'AS', expr_ast, 'expr' ] ] outer_aggr_ast = [ aggr_func_name, aggr_func_distinct, [ 'COLUMN', 't', 'expr' ] ] + if aggr_func_name == 'GROUP_CONCAT' and sep is not None: + outer_aggr_ast.append(['VALUE', sep]) def ast_transformer(ast): return [ 'SELECT', [ 'AGGREGATES', outer_aggr_ast ], [ 'FROM', [ outer_alias, 'SELECT', ast[1:] ] ] ] @@ -446,7 +450,11 @@ def ast_transformer(ast): else: aggr_ast = [ 'COUNT', True if aggr_func_distinct is None else aggr_func_distinct, translator.expr_columns[0] ] - else: aggr_ast = [ aggr_func_name, aggr_func_distinct, translator.expr_columns[0] ] + else: + aggr_ast = [ aggr_func_name, aggr_func_distinct, translator.expr_columns[0] ] + if aggr_func_name == 'GROUP_CONCAT' and sep is not None: + aggr_ast.append(['VALUE', sep]) + if aggr_ast: select_ast = [ 'AGGREGATES', aggr_ast ] elif isinstance(translator.expr_type, EntityMeta) and not translator.parent \ and not translator.aggregated and not translator.optimize: diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 3180ae879..16f80a030 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -509,6 +509,11 @@ def avg(iter): if not count: return None return sum / count +def group_concat(items, sep=','): + if items is None: + return None + return str(sep).join(str(item) for item in items) + def coalesce(*args): for arg in args: if arg is not None: From befdb72267e8f4c5f494bd40eb9d815ff90a99ad Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 13:04:02 +0300 Subject: [PATCH 293/547] group_concat for QuerySetMonad --- pony/orm/sqltranslation.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 632e0eaea..a1c9db736 100644 --- 
a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2179,6 +2179,15 @@ class FuncAvgMonad(FuncMonad): def call(monad, x, distinct=None): return x.aggregate('AVG', distinct) +class FuncGroupConcatMonad(FuncMonad): + func = utils.group_concat, core.group_concat + def call(monad, x, sep=None, distinct=None): + if sep is not None: + if not(isinstance(sep, StringConstMonad) and isinstance(sep.value, basestring)): + throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % ast2src(sep.node)) + sep = sep.value + return x.aggregate('GROUP_CONCAT', distinct=distinct, sep=sep) + class FuncCoalesceMonad(FuncMonad): func = coalesce def call(monad, *args): @@ -2783,7 +2792,7 @@ def count(monad, distinct=None): if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] return translator.ExprMonad.new(translator, int, sql_ast) len = count - def aggregate(monad, func_name, distinct=None): + def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator sub = monad.subtranslator @@ -2799,12 +2808,22 @@ def aggregate(monad, func_name, distinct=None): if expr_type not in comparable_types: throw(TypeError, "Function %s() cannot be applied to type %r in {EXPR}" % (func_name.lower(), type2str(expr_type))) + elif func_name == 'GROUP_CONCAT': + pass else: assert False # pragma: no cover assert len(sub.expr_columns) == 1 aggr_ast = [ func_name, distinct, sub.expr_columns[0] ] + if func_name == 'GROUP_CONCAT': + if sep is not None: + aggr_ast.append(['VALUE', sep]) select_ast = [ 'AGGREGATES', aggr_ast ] sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] - result_type = float if func_name == 'AVG' else expr_type + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = basestring + else: + result_type = expr_type return translator.ExprMonad.new(translator, 
result_type, sql_ast) def call_count(monad, distinct=None): return monad.count(distinct=distinct) @@ -2816,6 +2835,11 @@ def call_max(monad): return monad.aggregate('MAX') def call_avg(monad, distinct=None): return monad.aggregate('AVG', distinct) + def call_group_concat(monad, sep=None, distinct=None): + if sep is not None: + if not isinstance(sep, basestring): + throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__) + return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) def find_or_create_having_ast(subquery_ast): groupby_offset = None From f1b749e03d00cfa6db4108178a2f380ee0ab9cb1 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 13:53:52 +0300 Subject: [PATCH 294/547] group_concat for other monads --- pony/orm/sqltranslation.py | 63 ++++++++++++++++++++++++++++++-------- 1 file changed, 50 insertions(+), 13 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a1c9db736..8957d531f 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -411,10 +411,13 @@ def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None if aggr_func_name: expr_type = translator.expr_type if isinstance(expr_type, EntityMeta): - if aggr_func_name is not 'COUNT': throw(TypeError, + if aggr_func_name == 'GROUP_CONCAT': + if expr_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") + elif aggr_func_name != 'COUNT': throw(TypeError, 'Attribute should be specified for %r aggregate function' % aggr_func_name.lower()) elif isinstance(expr_type, tuple): - if aggr_func_name is not 'COUNT': throw(TypeError, + if aggr_func_name != 'COUNT': throw(TypeError, 'Single attribute should be specified for %r aggregate function' % aggr_func_name.lower()) else: if aggr_func_name in ('SUM', 'AVG') and expr_type not in numeric_types: @@ -1062,7 +1065,7 @@ def count(monad, distinct=None): result = 
translator.ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ]) result.aggregated = True return result - def aggregate(monad, func_name, distinct=None): + def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct) translator = monad.translator if monad.aggregated: throw(TranslationError, 'Aggregated functions cannot be nested. Got: {EXPR}') @@ -1077,6 +1080,9 @@ def aggregate(monad, func_name, distinct=None): if expr_type not in comparable_types: throw(TypeError, "Function '%s' cannot be applied to type %r in {EXPR}" % (func_name, type2str(expr_type))) + elif func_name == 'GROUP_CONCAT': + if isinstance(expr_type, EntityMeta) and expr_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover expr = monad.getsql() if len(expr) == 1: expr = expr[0] @@ -1086,11 +1092,18 @@ def aggregate(monad, func_name, distinct=None): 'with composite primary keys inside aggregate functions. 
Got: {EXPR} ' '(you can suggest us how to write SQL for this query)' % translator.dialect) - if func_name == 'AVG': result_type = float - else: result_type = expr_type + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = expr_type if distinct is None: distinct = getattr(monad, 'forced_distinct', False) and func_name in ('SUM', 'AVG') aggr_ast = [ func_name, distinct, expr ] + if func_name == 'GROUP_CONCAT': + if sep is not None: + aggr_ast.append(['VALUE', sep]) result = translator.ExprMonad.new(translator, result_type, aggr_ast) result.aggregated = True return result @@ -1206,7 +1219,7 @@ def __call__(monad, *args, **kwargs): def contains(monad, item, not_in=False): raise_forgot_parentheses(monad) def nonzero(monad): raise_forgot_parentheses(monad) def negate(monad): raise_forgot_parentheses(monad) - def aggregate(monad, func_name, distinct=None): raise_forgot_parentheses(monad) + def aggregate(monad, func_name, distinct=None, sep=None): raise_forgot_parentheses(monad) def __getitem__(monad, key): raise_forgot_parentheses(monad) def __add__(monad, monad2): raise_forgot_parentheses(monad) @@ -2440,7 +2453,7 @@ def count(monad, distinct=None): else: result.nogroup = True return result len = count - def aggregate(monad, func_name, distinct=None): + def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator item_type = monad.type.item_type @@ -2453,16 +2466,31 @@ def aggregate(monad, func_name, distinct=None): if item_type not in comparable_types: throw(TypeError, "Function %s() expects query or items of comparable type, got %r in {EXPR}" % (func_name.lower(), type2str(item_type))) + elif func_name == 'GROUP_CONCAT': + if isinstance(item_type, EntityMeta) and item_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with 
composite primary key") else: assert False # pragma: no cover - make_aggr = lambda expr_list: [ func_name, distinct ] + expr_list + def make_aggr(expr_list): + result = [ func_name, distinct ] + expr_list + if sep is not None: + assert func_name == 'GROUP_CONCAT' + result.append(['VALUE', sep]) + return result + + # make_aggr = lambda expr_list: [ func_name, distinct ] + expr_list if translator.hint_join: sql_ast, optimized = monad._joined_subselect(make_aggr, coalesce_to_zero=(func_name=='SUM')) else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr) - result_type = float if func_name == 'AVG' else item_type + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = item_type translator.aggregated_subquery_paths.add(monad.tableref.name_path) result = translator.ExprMonad.new(monad.translator, result_type, sql_ast) if optimized: result.aggregated = True @@ -2622,7 +2650,7 @@ def __init__(monad, op, sqlop, left, right): monad.sqlop = sqlop monad.left = left monad.right = right - def aggregate(monad, func_name, distinct=None): + def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator subquery = Subquery(translator.subquery) @@ -2631,8 +2659,16 @@ def aggregate(monad, func_name, distinct=None): outer_cond = subquery.from_ast[1].pop() if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] else: subquery.outer_conditions = [ outer_cond ] - result_type = float if func_name == 'AVG' else monad.type.item_type + if func_name == 'AVG': + result_type = float + elif func_name == 'GROUP_CONCAT': + result_type = unicode + else: + result_type = monad.type.item_type aggr_ast = [ func_name, distinct, expr ] + if func_name == 'GROUP_CONCAT': + if sep is not None: + aggr_ast.append(['VALUE', sep]) if translator.optimize != monad.tableref.name_path: 
sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], subquery.from_ast, @@ -2809,7 +2845,8 @@ def aggregate(monad, func_name, distinct=None, sep=None): "Function %s() cannot be applied to type %r in {EXPR}" % (func_name.lower(), type2str(expr_type))) elif func_name == 'GROUP_CONCAT': - pass + if isinstance(expr_type, EntityMeta) and expr_type._pk_is_composite_: + throw(TypeError, "`group_concat` cannot be used with entity with composite primary key") else: assert False # pragma: no cover assert len(sub.expr_columns) == 1 aggr_ast = [ func_name, distinct, sub.expr_columns[0] ] @@ -2821,7 +2858,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): if func_name == 'AVG': result_type = float elif func_name == 'GROUP_CONCAT': - result_type = basestring + result_type = unicode else: result_type = expr_type return translator.ExprMonad.new(translator, result_type, sql_ast) From 8af23f786607483a5be8338cd165f9fa7e14b5d8 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 14:05:21 +0300 Subject: [PATCH 295/547] PostgreSQL support of group_concat --- pony/orm/dbproviders/postgres.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index d77967c19..eb332de94 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -113,6 +113,15 @@ def JSON_CONTAINS(builder, expr, path, key): return (builder.JSON_QUERY(expr, path) if path else builder(expr)), ' ? 
', builder(key) def JSON_ARRAY_LENGTH(builder, value): return 'jsonb_array_length(', builder(value), ')' + def GROUP_CONCAT(builder, distinct, expr, sep=None): + assert distinct in (None, True, False) + result = distinct and 'string_agg(distinct ' or 'string_agg(', builder(expr), '::text' + if sep is not None: + result = result, ', ', builder(sep) + else: + result = result, ", ','" + return result, ')' + class PGStrConverter(dbapiprovider.StrConverter): if PY2: From fa4a18f917ca4218831b6eaf3a2996a8666d51ba Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 15:21:49 +0300 Subject: [PATCH 296/547] group_concat Oracle support --- pony/orm/dbproviders/oracle.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 07af1818d..867ca2aa1 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -277,6 +277,14 @@ def JSON_CONTAINS(builder, expr, path, key): return result def JSON_ARRAY_LENGTH(builder, value): throw(TranslationError, 'Oracle does not provide `length` function for JSON arrays') + def GROUP_CONCAT(builder, distinct, expr, sep=None): + assert distinct in (None, True, False) + result = 'LISTAGG(', builder(expr) + if sep is not None: + result = result, ', ', builder(sep) + else: + result = result, ", ','" + return result, ') WITHIN GROUP(ORDER BY 1)' json_item_re = re.compile('[\w\s]*') From bbf55f71df61bbe77039f6a2cb7eff391dac51f3 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Thu, 5 Jul 2018 15:35:26 +0300 Subject: [PATCH 297/547] tests for group_concat --- pony/orm/tests/queries.txt | 30 ++++++++++++++++ .../tests/test_declarative_query_set_monad.py | 36 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 2fb773fc4..6d16da848 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -914,3 +914,33 @@ Oracle: SELECT "s"."ID", "s"."NAME", 
"s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" FROM "STUDENT" "s" WHERE MOD("s"."ID", 2) = 0 + +# Test group_concat: + +>>> select((g, group_concat(s.name, '+')) for g in Group for s in g.students) + +SELECT "g"."number", GROUP_CONCAT("s"."name", '+') +FROM "Group" "g", "Student" "s" +WHERE "g"."number" = "s"."group" +GROUP BY "g"."number" + +PostgreSQL: + +SELECT "g"."number", string_agg("s"."name"::text, '+') +FROM "group" "g", "student" "s" +WHERE "g"."number" = "s"."group" +GROUP BY "g"."number" + +MySQL: + +SELECT `g`.`number`, GROUP_CONCAT(`s`.`name`, '+') +FROM `group` `g`, `student` `s` +WHERE `g`.`number` = `s`.`group` +GROUP BY `g`.`number` + +Oracle: + +SELECT "g"."NUMBER", LISTAGG("s"."NAME", '+') WITHIN GROUP(ORDER BY 1) +FROM "GROUP" "g", "STUDENT" "s" +WHERE "g"."NUMBER" = "s"."GROUP" +GROUP BY "g"."NUMBER" diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 1197234b4..10d20313f 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -241,5 +241,41 @@ def test_hint_join_4(self): result = set(select(g for g in Group if JOIN(g in select(s.group for s in g.students)))) self.assertEqual(result, {Group[1], Group[2]}) + def test_group_concat_1(self): + result = select(s.name for s in Student).group_concat() + self.assertEqual(result, 'S1,S2,S3') + + def test_group_concat_2(self): + result = select(s.name for s in Student).group_concat('-') + self.assertEqual(result, 'S1-S2-S3') + + def test_group_concat_3(self): + result = select(s for s in Student if s.name in group_concat(s.name for s in Student))[:] + self.assertEqual(set(result), {Student[1], Student[2], Student[3]}) + + def test_group_concat_4(self): + result = Student.select().group_concat() + self.assertEqual(result, '1,2,3') + + def test_group_concat_5(self): + result = Student.select().group_concat('.') + self.assertEqual(result, '1.2.3') + + 
@raises_exception(TypeError, '`group_concat` cannot be used with entity with composite primary key') + def test_group_concat_6(self): + select(group_concat(s.courses, '-') for s in Student) + + def test_group_concat_7(self): + result = select(group_concat(c.semester) for c in Course)[:] + self.assertEqual(result[0], '1,1,2') + + def test_group_concat_8(self): + result = select(group_concat(c.semester, '-') for c in Course)[:] + self.assertEqual(result[0], '1-1-2') + + def test_group_concat_9(self): + result = select(group_concat(c.semester, distinct=True) for c in Course)[:] + self.assertEqual(result[0], '1,2') + if __name__ == "__main__": unittest.main() From 2500293c00c79c224639a34c5c32d0904b48c675 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 11 Jul 2018 11:17:36 +0300 Subject: [PATCH 298/547] make_aggrfunc fix and group_concat tests --- pony/orm/core.py | 10 +++++----- pony/orm/tests/test_declarative_query_set_monad.py | 9 +++++++++ 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 759341f08..f22577914 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5202,15 +5202,15 @@ def delete(*args): def make_aggrfunc(std_func): def aggrfunc(*args, **kwargs): - if kwargs: return std_func(*args, **kwargs) - if len(args) != 1: return std_func(*args) + if not args: + return std_func(**kwargs) arg = args[0] if type(arg) is types.GeneratorType: try: iterator = arg.gi_frame.f_locals['.0'] - except: return std_func(*args) + except: return std_func(*args, **kwargs) if isinstance(iterator, EntityIter): - return getattr(select(arg), std_func.__name__)() - return std_func(*args) + return getattr(select(arg), std_func.__name__)(*args[1:], **kwargs) + return std_func(*args, **kwargs) aggrfunc.__name__ = std_func.__name__ return aggrfunc diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 10d20313f..c91517a44 100644 --- 
a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -277,5 +277,14 @@ def test_group_concat_9(self): result = select(group_concat(c.semester, distinct=True) for c in Course)[:] self.assertEqual(result[0], '1,2') + def test_group_concat_10(self): + result = group_concat((s.name for s in Student if int(s.name[1]) > 1), sep='-') + self.assertEqual(result, 'S2-S3') + + def test_group_concat_11(self): + result = group_concat((c.semester for c in Course), distinct=True) + self.assertEqual(result, '1,2') + + if __name__ == "__main__": unittest.main() From c0ff175c3b3d4c0a90ff535f3c8f50260c5f7f2f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 13 Jul 2018 10:54:08 +0300 Subject: [PATCH 299/547] Fix transformer.com_generator_expression in Python 3.7 for generators with multiple for-loops --- pony/thirdparty/compiler/transformer.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index 993ea4a35..d9cb9b5a9 100644 --- a/pony/thirdparty/compiler/transformer.py +++ b/pony/thirdparty/compiler/transformer.py @@ -1226,12 +1226,17 @@ def com_generator_expression(self, expr, node): # comp_for: 'for' exprlist 'in' test [comp_iter] # comp_if: 'if' test [comp_iter] - if sys.version_info >= (3, 7): - node = node[1] # remove async part + PY37 = sys.version_info >= (3, 7) - lineno = node[1][2] fors = [] while node: + if PY37 and node[0] == symbol.comp_for: + node = node[1] + assert node[0] == symbol.sync_comp_for + + lineno = node[1][2] + assert lineno is None or isinstance(lineno, int) + t = node[1][1] if t == 'for': assignNode = self.com_assign(node[2], OP_ASSIGN) @@ -1254,7 +1259,7 @@ def com_generator_expression(self, expr, node): else: raise SyntaxError("unexpected generator expression element: %s %d" % (node, lineno)) fors[0].is_outmost = True - return GenExpr(GenExprInner(expr, fors), 
lineno=lineno) + return GenExpr(GenExprInner(expr, fors), lineno=expr.lineno) def com_dictorsetmaker(self, nodelist): # dictorsetmaker: ( (test ':' test (comp_for | (',' test ':' test)* [','])) | From 913ed872bab17da22be74f5c0207a5ec409c9dab Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 16 Jul 2018 17:59:13 +0300 Subject: [PATCH 300/547] Refactoring of get_normalize_type_of for Flask support --- pony/orm/core.py | 6 ++--- pony/orm/dbapiprovider.py | 14 +++++----- pony/orm/ormtypes.py | 53 +++++++++++++++++++++++++++----------- pony/orm/sqltranslation.py | 6 ++--- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index f22577914..991fd3e48 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -19,7 +19,7 @@ import pony from pony import options from pony.orm.decompiling import decompile -from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, get_normalized_type_of, Json, TrackedValue +from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, normalize, Json, TrackedValue from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, @@ -5258,7 +5258,7 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): if src == 'None' and value is not None: throw(TranslationError) if src == 'True' and value is not True: throw(TranslationError) if src == 'False' and value is not False: throw(TranslationError) - try: vartypes[key] = get_normalized_type_of(value) + try: vartypes[key], value = normalize(value) except TypeError: if not isinstance(value, dict): unsupported = False @@ -5269,7 +5269,7 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): typename = type(value).__name__ if src == '.0': throw(TypeError, 'Cannot iterate over non-entity object') throw(TypeError, 'Expression `%s` has unsupported 
type %r' % (src, typename)) - vartypes[key] = get_normalized_type_of(value) + vartypes[key], value = normalize(value) vars[key] = value return vars, vartypes diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index a81956b7d..5f1b7129e 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -134,23 +134,23 @@ def get_default_m2m_table_name(provider, attr, reverse): return provider.normalize_name(name) def get_default_column_names(provider, attr, reverse_pk_columns=None): - normalize = provider.normalize_name + normalize_name = provider.normalize_name if reverse_pk_columns is None: - return [ normalize(attr.name) ] + return [ normalize_name(attr.name) ] elif len(reverse_pk_columns) == 1: - return [ normalize(attr.name) ] + return [ normalize_name(attr.name) ] else: prefix = attr.name + '_' - return [ normalize(prefix + column) for column in reverse_pk_columns ] + return [ normalize_name(prefix + column) for column in reverse_pk_columns ] def get_default_m2m_column_names(provider, entity): - normalize = provider.normalize_name + normalize_name = provider.normalize_name columns = entity._get_pk_columns_() if len(columns) == 1: - return [ normalize(entity.__name__.lower()) ] + return [ normalize_name(entity.__name__.lower()) ] else: prefix = entity.__name__.lower() + '_' - return [ normalize(prefix + column) for column in columns ] + return [ normalize_name(prefix + column) for column in columns ] def get_default_index_name(provider, table_name, column_names, is_pk=False, is_unique=False, m2m=False): if is_pk: index_name = 'pk_%s' % provider.base_name(table_name) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 75af93aa9..0f9f012cd 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -103,8 +103,7 @@ def __deepcopy__(self, memo): def __init__(self, sql, globals=None, locals=None, result_type=None): self.sql = sql self.items, self.codes = parse_raw_sql(sql) - self.values = tuple(eval(code, globals, locals) for 
code in self.codes) - self.types = tuple(get_normalized_type_of(value) for value in self.values) + self.types, self.values = normalize(tuple(eval(code, globals, locals) for code in self.codes)) self.result_type = result_type def _get_type_(self): return RawSQLType(self.sql, self.items, self.types, self.result_type) @@ -130,22 +129,46 @@ def __ne__(self, other): function_types = {type, types.FunctionType, types.BuiltinFunctionType} type_normalization_dict = { long : int } if PY2 else {} -def get_normalized_type_of(value): +def normalize(value): t = type(value) - if t is tuple: return tuple(get_normalized_type_of(item) for item in value) - if t.__name__ == 'EntityMeta': return SetType(value) - if t.__name__ == 'EntityIter': return SetType(value.entity) + if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__: + value = value._get_current_object() + t = type(value) + + if t is tuple: + item_types, item_values = [], [] + for item in value: + item_type, item_value = normalize(item) + item_values.append(item_value) + item_types.append(item_type) + return tuple(item_types), tuple(item_values) + + if t.__name__ == 'EntityMeta': + return SetType(value), value + + if t.__name__ == 'EntityIter': + return SetType(value.entity), value + if PY2 and isinstance(value, str): - try: value.decode('ascii') - except UnicodeDecodeError: throw(TypeError, - 'The bytestring %r contains non-ascii symbols. Try to pass unicode string instead' % value) - else: return unicode - elif isinstance(value, unicode): return unicode - if t in function_types: return FuncType(value) - if t is types.MethodType: return MethodType(value) + try: + value.decode('ascii') + except UnicodeDecodeError: + throw(TypeError, 'The bytestring %r contains non-ascii symbols. 
Try to pass unicode string instead' % value) + else: + return unicode, value + elif isinstance(value, unicode): + return unicode, value + + if t in function_types: + return FuncType(value), value + + if t is types.MethodType: + return MethodType(value), value + if hasattr(value, '_get_type_'): - return value._get_type_() - return normalize_type(t) + return value._get_type_(), value + + return normalize_type(t), value def normalize_type(t): tt = type(t) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 8957d531f..62795285d 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -16,7 +16,7 @@ from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ - get_normalized_type_of, normalize_type, coerce_types, are_comparable_types, \ + normalize, normalize_type, coerce_types, are_comparable_types, \ Json from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper @@ -1832,7 +1832,7 @@ def getsql(monad): class ConstMonad(Monad): @staticmethod def new(translator, value): - value_type = get_normalized_type_of(value) + value_type, value = normalize(value) if value_type in numeric_types: cls = translator.NumericConstMonad elif value_type is unicode: cls = translator.StringConstMonad elif value_type is date: cls = translator.DateConstMonad @@ -1851,7 +1851,7 @@ def __new__(cls, *args): if cls is ConstMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, translator, value): - value_type = get_normalized_type_of(value) + value_type, value = normalize(value) Monad.__init__(monad, translator, value_type) monad.value = value def getsql(monad, subquery=None): From 30c77bccfbfddfa7be46fc5e0e748fe77cfc96f2 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 20 Jul 2018 14:22:38 +0300 Subject: [PATCH 
301/547] Support is_empty for set attributes in query --- pony/orm/sqltranslation.py | 1 + pony/orm/tests/test_declarative_attr_set_monad.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 62795285d..59815cdcd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2508,6 +2508,7 @@ def negate(monad): [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] translator = monad.translator return translator.BoolExprMonad(translator, sql_ast) + call_is_empty = negate def make_tableref(monad, subquery): parent = monad.parent attr = monad.attr diff --git a/pony/orm/tests/test_declarative_attr_set_monad.py b/pony/orm/tests/test_declarative_attr_set_monad.py index f6f9a6418..938f882d2 100644 --- a/pony/orm/tests/test_declarative_attr_set_monad.py +++ b/pony/orm/tests/test_declarative_attr_set_monad.py @@ -165,6 +165,9 @@ def test26(self): @raises_exception(AttributeError, 'g.students.name.foo') def test27(self): select(g for g in Group if g.students.name.foo == 1) + def test28(self): + groups = set(select(g for g in Group if not g.students.is_empty())) + self.assertEqual(groups, {Group[41], Group[42], Group[44]}) if __name__ == "__main__": unittest.main() From 6e56b73eace07ba3ef37c81fc746cdafb6e31305 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 20 Jul 2018 18:37:00 +0300 Subject: [PATCH 302/547] SQLite3 like is now case sensitive --- pony/orm/dbproviders/sqlite.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index f46e385d2..e18a2065a 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -537,6 +537,8 @@ def create_function(name, num_params, func): if sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') + + con.execute('PRAGMA case_sensitive_like = true') def disconnect(pool): if pool.filename != ':memory:': Pool.disconnect(pool) 
From 1d122057919e49412799ef25b3a1e1d4d83bb88e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 24 May 2018 10:48:49 +0300 Subject: [PATCH 303/547] Move code around --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 991fd3e48..5bacaae04 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -269,8 +269,6 @@ def adapt_sql(sql, paramstyle): adapted_sql_cache[(sql, paramstyle)] = result return result -num_counter = itertools.count() - class Local(localbase): def __init__(local): local.debug = False @@ -1594,6 +1592,8 @@ def avg_time(stat): if not stat.db_count: return None return stat.sum_time / stat.db_count +num_counter = itertools.count() + class SessionCache(object): def __init__(cache, database): cache.is_alive = True From 5015d634a962bf003f481a641180d94e8adc0aed Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 13:00:13 +0300 Subject: [PATCH 304/547] Always create new translator using translator_cls --- pony/orm/sqltranslation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 59815cdcd..0e575b74a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -648,7 +648,8 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ return translator def preGenExpr(translator, node): inner_tree = node.code - subtranslator = translator.__class__(inner_tree, translator.extractors, translator.vartypes, translator) + translator_cls = translator.__class__ + subtranslator = translator_cls(inner_tree, translator.extractors, translator.vartypes, translator) return translator.QuerySetMonad(translator, subtranslator) def postGenExprIf(translator, node): monad = node.test.monad @@ -771,7 +772,8 @@ def preCallFunc(translator, node): name_ast.monad = entity_monad for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ 
if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) - subtranslator = translator.__class__(inner_expr, translator.extractors, translator.vartypes, translator) + translator_cls = translator.__class__ + subtranslator = translator_cls(inner_expr, translator.extractors, translator.vartypes, translator) return translator.QuerySetMonad(translator, subtranslator) def postCallFunc(translator, node): args = [] From 88cb468cd7d6726dfe5db9dc4755b22cfc605859 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 30 May 2018 15:14:46 +0300 Subject: [PATCH 305/547] Every translator now gets unique filter_num from core.filter_num_counter --- pony/orm/core.py | 18 ++++++++++-------- pony/orm/sqltranslation.py | 8 ++++---- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 5bacaae04..f4bc7b7ad 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5276,12 +5276,14 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): def unpickle_query(query_result): return query_result +filter_num_counter = itertools.count() + class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) extractors, tree, extractors_key = create_extractors( code_key, tree, globals, locals, special_functions, const_functions) - filter_num = 0 + filter_num = next(filter_num_counter) vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) node = tree.quals[0].iter @@ -5305,11 +5307,12 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False pickled_tree = pickle_ast(tree) tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls - translator = translator_cls(tree_copy, extractors, vartypes, left_join=left_join) + translator = translator_cls(tree_copy, filter_num, extractors, vartypes, left_join=left_join) name_path = 
translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) - try: translator = translator_cls(tree_copy, extractors, vartypes, left_join=True, optimize=name_path) + try: translator = translator_cls(tree_copy, filter_num, extractors, vartypes, + left_join=True, optimize=name_path) except OptimizationFailed: translator.optimization_failed = True translator.pickled_tree = pickled_tree database._translator_cache[query._key] = translator @@ -5606,7 +5609,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names throw(TypeError, 'Incorrect number of lambda arguments. ' 'Expected: %d, got: %d' % (expr_count, len(argnames))) - filter_num = len(query._filters) + 1 + filter_num = next(filter_num_counter) extractors, func_ast, extractors_key = create_extractors( func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) @@ -5627,11 +5630,10 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names name_path = new_translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) - prev_extractors = prev_translator.extractors - prev_vartypes = prev_translator.vartypes translator_cls = prev_translator.__class__ - new_translator = translator_cls(tree_copy, prev_extractors, prev_vartypes, - left_join=True, optimize=name_path) + new_translator = translator_cls( + tree_copy, prev_translator.original_filter_num, prev_translator.extractors, prev_translator.vartypes, + left_join=True, optimize=name_path) new_translator = query._reapply_filters(new_translator) new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) query._database._translator_cache[new_key] = new_translator diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 0e575b74a..7979be797 100644 --- a/pony/orm/sqltranslation.py +++ 
b/pony/orm/sqltranslation.py @@ -163,12 +163,12 @@ def call(translator, method, node): else: throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: %s' % ast2src(node)) return monad - def __init__(translator, tree, extractors, vartypes, parent_translator=None, left_join=False, optimize=None): + def __init__(translator, tree, filter_num, extractors, vartypes, parent_translator=None, left_join=False, optimize=None): assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) translator.database = None translator.lambda_argnames = None - translator.filter_num = parent_translator.filter_num if parent_translator is not None else 0 + translator.filter_num = translator.original_filter_num = filter_num translator.extractors = extractors translator.vartypes = vartypes.copy() translator.parent = parent_translator @@ -649,7 +649,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ - subtranslator = translator_cls(inner_tree, translator.extractors, translator.vartypes, translator) + subtranslator = translator_cls(inner_tree, translator.filter_num, translator.extractors, translator.vartypes, translator) return translator.QuerySetMonad(translator, subtranslator) def postGenExprIf(translator, node): monad = node.test.monad @@ -773,7 +773,7 @@ def preCallFunc(translator, node): for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ - subtranslator = translator_cls(inner_expr, translator.extractors, translator.vartypes, translator) + subtranslator = translator_cls(inner_expr, translator.filter_num, translator.extractors, translator.vartypes, translator) return translator.QuerySetMonad(translator, subtranslator) def postCallFunc(translator, node): args = [] From 
642899bcf6ae8203b5b328bce133d7de806e449f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 21 Mar 2018 14:16:52 +0300 Subject: [PATCH 306/547] Closure bug fixed: "free variable referenced before assignment in enclosing scope" --- pony/orm/core.py | 5 ++++- pony/orm/tests/test_query.py | 8 ++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index f4bc7b7ad..cb68aa5ac 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5246,7 +5246,10 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): if cells: locals = locals.copy() for name, cell in cells.items(): - locals[name] = cell.cell_contents + try: + locals[name] = cell.cell_contents + except ValueError: + throw(NameError, 'Free variable `%s` referenced before assignment in enclosing scope' % name) vars = {} vartypes = HashableDict() for src, code in iteritems(extractors): diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 5a7f2fc3d..be7977454 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -146,6 +146,14 @@ def find_by_gpa(gpa): q = select(s for s in Student) q = q.filter(fn) self.assertEqual(list(q), [ Student[2], Student[3] ]) + @raises_exception(NameError, 'Free variable `gpa` referenced before assignment in enclosing scope') + def test_closures_3(self): + def find_by_gpa(): + if False: + gpa = Decimal('3.1') + return lambda s: s.gpa > gpa + fn = find_by_gpa() + students = list(Student.select(fn)) if __name__ == '__main__': unittest.main() From db16dc0c8cd1fdd60e6450b9f026b3ddf433e9c9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 11 Apr 2018 13:40:32 +0300 Subject: [PATCH 307/547] Fix HashableDict.__hash__() in Python3 --- pony/utils/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 16f80a030..aa71bcd0d 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -572,7 
+572,11 @@ class HashableDict(dict): def __hash__(self): result = getattr(self, '_hash', None) if result is None: - result = self._hash = hash(tuple(sorted(self.items()))) + result = 0 + for key, value in self.items(): + result ^= hash(key) + result ^= hash(value) + self._hash = result return result def __deepcopy__(self, memo): if getattr(self, '_hash', None) is not None: From edcb9d2c1e34b70a76dd5e25625f7812ebace018 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 30 Oct 2017 21:48:41 +0300 Subject: [PATCH 308/547] Renaming: additional_internal_names -> outer_names --- pony/orm/asttranslation.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 073cf8045..beea40629 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -209,7 +209,7 @@ def postKeyword(translator, node): class PreTranslator(ASTTranslator): def __init__(translator, tree, globals, locals, - special_functions, const_functions, additional_internal_names=()): + special_functions, const_functions, outer_names=()): ASTTranslator.__init__(translator, tree) translator.getattr_nodes = set() translator.globals = globals @@ -217,8 +217,8 @@ def __init__(translator, tree, globals, locals, translator.special_functions = special_functions translator.const_functions = const_functions translator.contexts = [] - if additional_internal_names: - translator.contexts.append(additional_internal_names) + if outer_names: + translator.contexts.append(outer_names) translator.externals = externals = set() translator.dispatch(tree) for node in externals.copy(): @@ -303,7 +303,7 @@ def postCallFunc(translator, node): getattr_cache = {} extractors_cache = {} -def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, additional_internal_names=()): +def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, outer_names=()): result = None 
getattr_extractors = getattr_cache.get(code_key) if getattr_extractors: @@ -319,7 +319,7 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ if not result: pretranslator = PreTranslator( - tree, globals, locals, special_functions, const_functions, additional_internal_names) + tree, globals, locals, special_functions, const_functions, outer_names) extractors = {} for node in pretranslator.externals: From ee25054c315c0313cf26acac564b305744735bd7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 2 Dec 2017 20:21:17 +0300 Subject: [PATCH 309/547] Refactoring: use extractor functions instead of codeobjects: eval(code, globals, locals) -> extractor(globals, locals) --- pony/orm/asttranslation.py | 19 +++++++++++-------- pony/orm/core.py | 14 ++++++-------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index beea40629..038b69c41 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -307,8 +307,8 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ result = None getattr_extractors = getattr_cache.get(code_key) if getattr_extractors: - getattr_attrnames = HashableDict({src: eval(code, globals, locals) - for src, code in iteritems(getattr_extractors)}) + getattr_attrnames = HashableDict({src: extractor(globals, locals) + for src, extractor in iteritems(getattr_extractors)}) extractors_key = HashableDict(code_key=code_key, getattr_attrnames=getattr_attrnames) try: result = extractors_cache.get(extractors_key) @@ -324,18 +324,21 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ extractors = {} for node in pretranslator.externals: src = node.src = ast2src(node) - if src == '.0': code = None - else: code = compile(src, src, 'eval') - extractors[src] = code + if src == '.0': + extractor = lambda globals, locals: locals['.0'] + else: + code = compile(src, src, 'eval') + extractor 
= lambda globals, locals, code=code: eval(code, globals, locals) + extractors[src] = extractor getattr_extractors = {} getattr_attrnames = HashableDict() for node in pretranslator.getattr_nodes: if node in pretranslator.externals: src = node.src - code = extractors[src] - getattr_extractors[src] = code - attrname_value = eval(code, globals, locals) + extractor = extractors[src] + getattr_extractors[src] = extractor + attrname_value = extractor(globals, locals) getattr_attrnames[src] = attrname_value elif isinstance(node, ast.Const): attrname_value = node.value diff --git a/pony/orm/core.py b/pony/orm/core.py index cb68aa5ac..d08bbbcd0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5252,15 +5252,13 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): throw(NameError, 'Free variable `%s` referenced before assignment in enclosing scope' % name) vars = {} vartypes = HashableDict() - for src, code in iteritems(extractors): + for src, extractor in iteritems(extractors): key = filter_num, src - if src == '.0': value = locals['.0'] - else: - try: value = eval(code, globals, locals) - except Exception as cause: raise ExprEvalError(src, cause) - if src == 'None' and value is not None: throw(TranslationError) - if src == 'True' and value is not True: throw(TranslationError) - if src == 'False' and value is not False: throw(TranslationError) + try: value = extractor(globals, locals) + except Exception as cause: raise ExprEvalError(src, cause) + if src == 'None' and value is not None: throw(TranslationError) + if src == 'True' and value is not True: throw(TranslationError) + if src == 'False' and value is not False: throw(TranslationError) try: vartypes[key], value = normalize(value) except TypeError: if not isinstance(value, dict): From 6a473494601d226e2a397e84bbdd993604e56e98 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 11:54:18 +0300 Subject: [PATCH 310/547] Reformat code --- pony/orm/asttranslation.py | 10 
+++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 038b69c41..9f8f74848 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -318,17 +318,17 @@ def create_extractors(code_key, tree, globals, locals, special_functions, const_ tree = copy_ast(tree) if not result: - pretranslator = PreTranslator( - tree, globals, locals, special_functions, const_functions, outer_names) - + pretranslator = PreTranslator(tree, globals, locals, special_functions, const_functions, outer_names) extractors = {} for node in pretranslator.externals: src = node.src = ast2src(node) if src == '.0': - extractor = lambda globals, locals: locals['.0'] + def extractor(globals, locals): + return locals['.0'] else: code = compile(src, src, 'eval') - extractor = lambda globals, locals, code=code: eval(code, globals, locals) + def extractor(globals, locals, code=code): + return eval(code, globals, locals) extractors[src] = extractor getattr_extractors = {} From d0f745a01918a89c6173dcf7bf92d693c3d69f30 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 13:02:17 +0300 Subject: [PATCH 311/547] Minor refactoring --- pony/orm/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index d08bbbcd0..b3f47f8ff 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5702,14 +5702,14 @@ def _apply_kwargs(query, kwargs, original_names=False): tup = (('apply_kwfilters', filterattrs, original_names),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup + new_vars = query._vars.copy() + new_vars.update(value_dict) new_translator = query._database._translator_cache.get(new_key) if new_translator is None: new_translator = translator.apply_kwfilters(filterattrs, original_names) query._database._translator_cache[new_key] = new_translator - new_query = query._clone(_key=new_key, 
_filters=new_filters, _translator=new_translator, - _next_kwarg_id=next_id, _vars=query._vars.copy()) - new_query._vars.update(value_dict) - return new_query + return query._clone(_key=new_key, _filters=new_filters, _translator=new_translator, + _next_kwarg_id=next_id, _vars=new_vars) @cut_traceback def __getitem__(query, key): if isinstance(key, slice): From d2fc1a3b9961e57e57acd2371be1c4904ee7d828 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 13:06:29 +0300 Subject: [PATCH 312/547] Local variable renaming --- pony/orm/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b3f47f8ff..394d650a2 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5617,9 +5617,9 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names if extractors: vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) - new_query_vars = query._vars.copy() - new_query_vars.update(vars) - else: new_query_vars, vartypes = query._vars, HashableDict() + new_vars = query._vars.copy() + new_vars.update(vars) + else: new_vars, vartypes = query._vars, HashableDict() tup = (('order_by' if order_by else 'where' if original_names else 'filter', extractors_key, vartypes),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) @@ -5638,7 +5638,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names new_translator = query._reapply_filters(new_translator) new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) query._database._translator_cache[new_key] = new_translator - return query._clone(_vars=new_query_vars, _key=new_key, _filters=new_filters, 
_translator=new_translator) + return query._clone(_vars=new_vars, _key=new_key, _filters=new_filters, _translator=new_translator) def _reapply_filters(query, translator): for tup in query._filters: method_name, args = tup[0], tup[1:] From 864f55f78701fb9643b1fa4d7546ec675197b5d5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 26 Feb 2018 17:15:40 +0300 Subject: [PATCH 313/547] Rename local variable --- pony/orm/dbproviders/oracle.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 867ca2aa1..d41448206 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -442,10 +442,10 @@ def normalize_name(provider, name): return name[:provider.max_name_len].upper() def normalize_vars(provider, vars, vartypes): - for name, value in iteritems(vars): + for key, value in iteritems(vars): if value == '': - vars[name] = None - vartypes[name] = NoneType + vars[key] = None + vartypes[key] = NoneType @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): From 329c8e7fc38eb0003a2de8dfbd93a80e07a31c71 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 17:24:15 +0300 Subject: [PATCH 314/547] Bug in utils.get_lambda_args() fixed --- pony/utils/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/utils/utils.py b/pony/utils/utils.py index aa71bcd0d..9bfce9d8d 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -158,7 +158,7 @@ def get_lambda_args(func): if type(func) is types.FunctionType: if hasattr(inspect, 'signature'): - names, argsname, kwname, defaults = [], None, None, None + names, argsname, kwname, defaults = [], None, None, [] for p in inspect.signature(func).parameters.values(): if p.default is not p.empty: defaults.append(p.default) From 841c1f0df3f8b57697c029eb2a3be741cb9d63bc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 17:27:45 +0300 
Subject: [PATCH 315/547] Add postLambda() method to PythonTranslator --- pony/orm/asttranslation.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 9f8f74848..198cb5ca3 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -79,6 +79,18 @@ def postGenExprIf(translator, node): return 'if %s' % node.test.src def postIfExp(translator, node): return '%s if %s else %s' % (node.then.src, node.test.src, node.else_.src) + def postLambda(translator, node): + argnames = list(node.argnames) + kwargs_name = argnames.pop() if node.kwargs else None + varargs_name = argnames.pop() if node.varargs else None + def_argnames = argnames[-len(node.defaults):] if node.defaults else [] + nodef_argnames = argnames[:-len(node.defaults)] if node.defaults else argnames + args = ', '.join(nodef_argnames) + d_args = ', '.join('%s=%s' % (argname, default.src) for argname, default in zip(def_argnames, node.defaults)) + v_arg = '*%s' % varargs_name if varargs_name else None + kw_arg = '**%s' % kwargs_name if kwargs_name else None + args = ', '.join(x for x in [args, d_args, v_arg, kw_arg] if x) + return 'lambda %s: %s' % (args, node.code.src) @priority(14) def postOr(translator, node): return ' or '.join(expr.src for expr in node.nodes) From 7fdf8e4a0b7a36e5a77d251df175e4ff3610aca4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 17:29:01 +0300 Subject: [PATCH 316/547] Fix Decompiler.MAKE_FUNCTION --- pony/orm/decompiling.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 71fda2abb..6a8b3a6f1 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, izip, xrange -import sys, types +import sys, types, inspect from opcode import opname as opnames, 
HAVE_ARGUMENT, EXTENDED_ARG, cmp_op from opcode import hasconst, hasname, hasjrel, haslocal, hascompare, hasfree @@ -363,24 +363,39 @@ def MAKE_CLOSURE(decompiler, argc): return decompiler.MAKE_FUNCTION(argc) def MAKE_FUNCTION(decompiler, argc): + defaults = [] + flags = 0 if sys.version_info >= (3, 6): - if argc: - if argc != 0x08: throw(NotImplementedError, argc) qualname = decompiler.stack.pop() tos = decompiler.stack.pop() - if (argc & 0x08): func_closure = decompiler.stack.pop() + if argc & 0x08: + func_closure = decompiler.stack.pop() + if argc & 0x04: + annotations = decompiler.stack.pop() + if argc & 0x02: + kwonly_defaults = decompiler.stack.pop() + if argc & 0x01: + defaults = decompiler.stack.pop() + throw(NotImplementedError) else: - if argc: throw(NotImplementedError) + if not PY2: + qualname = decompiler.stack.pop() tos = decompiler.stack.pop() - if not PY2: tos = decompiler.stack.pop() + if argc: + defaults = [ decompiler.stack.pop() for i in range(argc) ] + defaults.reverse() codeobject = tos.value func_decompiler = Decompiler(codeobject) # decompiler.names.update(decompiler.names) ??? 
if codeobject.co_varnames[:1] == ('.0',): return func_decompiler.ast # generator - argnames = codeobject.co_varnames[:codeobject.co_argcount] - defaults = [] # todo - flags = 0 # todo + argnames, varargs, keywords = inspect.getargs(codeobject) + if varargs: + argnames.append(varargs) + flags |= inspect.CO_VARARGS + if keywords: + argnames.append(keywords) + flags |= inspect.CO_VARKEYWORDS return ast.Lambda(argnames, defaults, flags, func_decompiler.ast) POP_JUMP_IF_FALSE = JUMP_IF_FALSE From 2cf9e703d66bacd1d18e175e783056364baad35d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 11 Apr 2018 15:25:09 +0300 Subject: [PATCH 317/547] Call base class __init__ from BoolMonad and BoolExprMonad --- pony/orm/sqltranslation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 7979be797..29f08bd0d 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1882,16 +1882,14 @@ class DatetimeConstMonad(DatetimeMixin, ConstMonad): pass class BoolMonad(Monad): def __init__(monad, translator): - monad.translator = translator - monad.type = bool + Monad.__init__(monad, translator, bool) sql_negation = { 'IN' : 'NOT_IN', 'EXISTS' : 'NOT_EXISTS', 'LIKE' : 'NOT_LIKE', 'BETWEEN' : 'NOT_BETWEEN', 'IS_NULL' : 'IS_NOT_NULL' } sql_negation.update((value, key) for key, value in items_list(sql_negation)) class BoolExprMonad(BoolMonad): def __init__(monad, translator, sql): - monad.translator = translator - monad.type = bool + BoolMonad.__init__(monad, translator) monad.sql = sql def getsql(monad, subquery=None): return [ monad.sql ] From 876c04a67101c06f8e786837289bab4df06c90ad Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Mon, 16 Apr 2018 18:30:21 +0300 Subject: [PATCH 318/547] Minor refactoring & typo fixed --- pony/orm/core.py | 22 +++++++++---------- .../tests/test_declarative_orderby_limit.py | 2 +- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git 
a/pony/orm/core.py b/pony/orm/core.py index 394d650a2..b5ca8d9fd 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5712,17 +5712,17 @@ def _apply_kwargs(query, kwargs, original_names=False): _next_kwarg_id=next_id, _vars=new_vars) @cut_traceback def __getitem__(query, key): - if isinstance(key, slice): - step = key.step - if step is not None and step != 1: throw(TypeError, "Parameter 'step' of slice object is not allowed here") - start = key.start - if start is None: start = 0 - elif start < 0: throw(TypeError, "Parameter 'start' of slice object cannot be negative") - stop = key.stop - if stop is None: - if not start: return query._fetch() - else: throw(TypeError, "Parameter 'stop' of slice object should be specified") - else: throw(TypeError, 'If you want apply index to query, convert it to list first') + if not isinstance(key, slice): + throw(TypeError, 'If you want apply index to a query, convert it to list first') + step = key.step + if step is not None and step != 1: throw(TypeError, "Parameter 'step' of slice object is not allowed here") + start = key.start + if start is None: start = 0 + elif start < 0: throw(TypeError, "Parameter 'start' of slice object cannot be negative") + stop = key.stop + if stop is None: + if not start: return query._fetch() + else: throw(TypeError, "Parameter 'stop' of slice object should be specified") if start >= stop: return [] return query._fetch(range=(start, stop)) @cut_traceback diff --git a/pony/orm/tests/test_declarative_orderby_limit.py b/pony/orm/tests/test_declarative_orderby_limit.py index fe36e1f1a..39aeca96d 100644 --- a/pony/orm/tests/test_declarative_orderby_limit.py +++ b/pony/orm/tests/test_declarative_orderby_limit.py @@ -82,7 +82,7 @@ def test11(self): def test12(self): students = select(s for s in Student).order_by(Student.id)[-3:2] - @raises_exception(TypeError, 'If you want apply index to query, convert it to list first') + @raises_exception(TypeError, 'If you want apply index to a query, convert it to 
list first') def test13(self): students = select(s for s in Student).order_by(Student.id)[3] self.assertEqual(students, Student[4]) From f8c110f7ba96bf5b7c09221dfab5c7a82e107ee4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 30 May 2018 17:40:16 +0300 Subject: [PATCH 319/547] Remove unused code --- pony/orm/dbproviders/oracle.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index d41448206..b35d28c12 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -124,11 +124,6 @@ class OraTranslator(sqltranslation.SQLTranslator): NoneMonad = OraNoneMonad ConstMonad = OraConstMonad - @classmethod - def get_normalized_type_of(translator, value): - if value == '': return NoneType - return sqltranslation.SQLTranslator.get_normalized_type_of(value) - class OraBuilder(SQLBuilder): dialect = 'Oracle' def INSERT(builder, table_name, columns, values, returning=None): From b8b51e74e20afc2f76e670c28fdf415d5fac48db Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 18:21:19 +0300 Subject: [PATCH 320/547] Throw exception on select(select(...) 
for x in ...), as we cannot translate it correctly to SQL --- pony/orm/sqltranslation.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 29f08bd0d..64d48e14b 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -288,8 +288,6 @@ def __init__(translator, tree, filter_num, extractors, vartypes, parent_translat if isinstance(expr_type, EntityMeta): monad.orderby_columns = list(xrange(1, len(expr_type._pk_columns_)+1)) if monad.aggregated: throw(TranslationError) - if translator.aggregated: translator.groupby_monads = [ monad ] - else: translator.distinct |= monad.requires_distinct() if isinstance(monad, translator.ObjectMixin): entity = monad.type tableref = monad.tableref @@ -297,6 +295,10 @@ def __init__(translator, tree, filter_num, extractors, vartypes, parent_translat entity = monad.type.item_type tableref = monad.make_tableref(translator.subquery) else: assert False # pragma: no cover + if translator.aggregated: + translator.groupby_monads = [ monad ] + else: + translator.distinct |= monad.requires_distinct() translator.tableref = tableref pk_only = parent_translator is not None or translator.aggregated alias, pk_columns = tableref.make_join(pk_only=pk_only) @@ -2715,6 +2717,8 @@ def __init__(monad, translator, subtranslator): monad.item_type = item_type monad_type = SetType(item_type) Monad.__init__(monad, translator, monad_type) + def requires_distinct(monad, joined=False): + assert False def contains(monad, item, not_in=False): translator = monad.translator check_comparable(item, monad, 'in') From 81b92d8e83e7f2a6a402eada02b4e652cae47ad1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 26 Feb 2018 17:36:04 +0300 Subject: [PATCH 321/547] Remove monad references from SQLTranslator --- pony/orm/dbproviders/sqlite.py | 7 +- pony/orm/sqltranslation.py | 370 ++++++++++++++++----------------- 2 files changed, 185 insertions(+), 192 deletions(-) 
diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index e18a2065a..d8d20e82f 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -12,9 +12,10 @@ from binascii import hexlify from functools import wraps -from pony.orm import core, dbschema, sqltranslation, dbapiprovider +from pony.orm import core, dbschema, dbapiprovider from pony.orm.core import log_orm from pony.orm.ormtypes import Json +from pony.orm.sqltranslation import SQLTranslator, StringExprMonad from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise, \ @@ -39,12 +40,12 @@ def func(translator, monad): sql = monad.getsql() assert len(sql) == 1 translator = monad.translator - return translator.StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) + return StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) func.__name__ = sqlop return func -class SQLiteTranslator(sqltranslation.SQLTranslator): +class SQLiteTranslator(SQLTranslator): dialect = 'SQLite' sqlite_version = sqlite.sqlite_version_info row_value_syntax = False diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 64d48e14b..5bf8a0deb 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -87,39 +87,39 @@ def dispatch_external(translator, node): t = translator.vartypes[varkey] tt = type(t) if t is NoneType: - monad = translator.ConstMonad.new(translator, None) + monad = ConstMonad.new(translator, None) elif tt is SetType: if isinstance(t.item_type, EntityMeta): - monad = translator.EntityMonad(translator, t.item_type) + monad = EntityMonad(translator, t.item_type) else: throw(NotImplementedError) # pragma: no cover elif tt is FuncType: func = t.func - func_monad_class = translator.registered_functions.get(func, translator.ErrorSpecialFuncMonad) + 
func_monad_class = translator.registered_functions.get(func, ErrorSpecialFuncMonad) monad = func_monad_class(translator, func) elif tt is MethodType: obj, func = t.obj, t.func if isinstance(obj, EntityMeta): - entity_monad = translator.EntityMonad(translator, obj) + entity_monad = EntityMonad(translator, obj) if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) - monad = translator.MethodMonad(entity_monad, func.__name__) + monad = MethodMonad(entity_monad, func.__name__) elif node.src == 'random': # For PyPy - monad = translator.FuncRandomMonad(translator, t) + monad = FuncRandomMonad(translator, t) else: throw(NotImplementedError) elif isinstance(node, ast.Name) and node.name in ('True', 'False'): value = True if node.name == 'True' else False - monad = translator.ConstMonad.new(translator, value) + monad = ConstMonad.new(translator, value) elif tt is tuple: params = [] for i, item_type in enumerate(t): if item_type is NoneType: throw(TypeError, 'Expression `%s` should not contain None values' % node.src) - param = translator.ParamMonad.new(translator, item_type, (varkey, i, None)) + param = ParamMonad.new(translator, item_type, (varkey, i, None)) params.append(param) - monad = translator.ListMonad(translator, params) + monad = ListMonad(translator, params) elif isinstance(t, RawSQLType): - monad = translator.RawSQLMonad(translator, t, varkey) + monad = RawSQLMonad(translator, t, varkey) else: - monad = translator.ParamMonad.new(translator, t, (varkey, None, None)) + monad = ParamMonad.new(translator, t, (varkey, None, None)) node.monad = monad monad.node = node monad.aggregated = monad.nogroup = False @@ -221,7 +221,7 @@ def __init__(translator, tree, filter_num, extractors, vartypes, parent_translat tableref = TableRef(subquery, name, entity) tablerefs[name] = tableref tableref.make_join() - node.monad = translator.ObjectIterMonad(translator, tableref, entity) + node.monad = ObjectIterMonad(translator, tableref, entity) else: 
attr_names = [] while isinstance(node, ast.Getattr): @@ -270,7 +270,7 @@ def __init__(translator, tree, filter_num, extractors, vartypes, parent_translat for if_ in qual.ifs: assert isinstance(if_, ast.GenExprIf) translator.dispatch(if_) - if isinstance(if_.monad, translator.AndMonad): cond_monads = if_.monad.operands + if isinstance(if_.monad, AndMonad): cond_monads = if_.monad.operands else: cond_monads = [ if_.monad ] for m in cond_monads: if not m.aggregated: translator.conditions.extend(m.getsql()) @@ -279,19 +279,19 @@ def __init__(translator, tree, filter_num, extractors, vartypes, parent_translat translator.dispatch(tree.expr) assert not translator.hint_join monad = tree.expr.monad - if isinstance(monad, translator.ParamMonad): throw(TranslationError, + if isinstance(monad, ParamMonad): throw(TranslationError, "External parameter '%s' cannot be used as query result" % ast2src(tree.expr)) - translator.expr_monads = monad.items if isinstance(monad, translator.ListMonad) else [ monad ] + translator.expr_monads = monad.items if isinstance(monad, ListMonad) else [ monad ] translator.groupby_monads = None expr_type = monad.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): monad.orderby_columns = list(xrange(1, len(expr_type._pk_columns_)+1)) if monad.aggregated: throw(TranslationError) - if isinstance(monad, translator.ObjectMixin): + if isinstance(monad, ObjectMixin): entity = monad.type tableref = monad.tableref - elif isinstance(monad, translator.AttrSetMonad): + elif isinstance(monad, AttrSetMonad): entity = monad.type.item_type tableref = monad.make_tableref(translator.subquery) else: assert False # pragma: no cover @@ -472,8 +472,8 @@ def ast_transformer(ast): having_conditions = translator.having_conditions[:] if is_not_null_checks: for monad in translator.expr_monads: - if isinstance(monad, translator.ObjectIterMonad): pass - elif isinstance(monad, translator.AttrMonad) and not 
monad.attr.nullable: pass + if isinstance(monad, ObjectIterMonad): pass + elif isinstance(monad, AttrMonad) and not monad.attr.nullable: pass else: notnull_conditions = [ [ 'IS_NOT_NULL', column_ast ] for column_ast in monad.getsql() ] if monad.aggregated: having_conditions.extend(notnull_conditions) @@ -607,12 +607,12 @@ def apply_kwfilters(translator, filterattrs, original_names=False): throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') monads = [] - none_monad = translator.NoneMonad(translator) + none_monad = NoneMonad(translator) for attr, id, is_none in filterattrs: attr_monad = object_monad.getattr(attr.name) if is_none: monads.append(CmpMonad('is', attr_monad, none_monad)) else: - param_monad = translator.ParamMonad.new(translator, attr.py_type, (id, None, None)) + param_monad = ParamMonad.new(translator, attr.py_type, (id, None, None)) monads.append(CmpMonad('==', attr_monad, param_monad)) for m in monads: translator.conditions.extend(m.getsql()) return translator @@ -631,7 +631,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ translator.inside_order_by = True new_order = [] for node in nodes: - if isinstance(node.monad, translator.SetMixin): + if isinstance(node.monad, SetMixin): t = node.monad.type.item_type if isinstance(type(t), type): t = t.__name__ throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' @@ -642,7 +642,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ else: for node in nodes: monad = node.monad - if isinstance(monad, translator.AndMonad): cond_monads = monad.operands + if isinstance(monad, AndMonad): cond_monads = monad.operands else: cond_monads = [ monad ] for m in cond_monads: if not m.aggregated: translator.conditions.extend(m.getsql()) @@ -652,7 +652,7 @@ def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ subtranslator = translator_cls(inner_tree, 
translator.filter_num, translator.extractors, translator.vartypes, translator) - return translator.QuerySetMonad(translator, subtranslator) + return QuerySetMonad(translator, subtranslator) def postGenExprIf(translator, node): monad = node.test.monad if monad.type is not bool: monad = monad.nonzero() @@ -677,21 +677,21 @@ def preCompare(translator, node): monads.append(monad) left = right if len(monads) == 1: return monads[0] - return translator.AndMonad(monads) + return AndMonad(monads) def postConst(translator, node): value = node.value if type(value) is frozenset: value = tuple(sorted(value)) if type(value) is not tuple: - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(translator, value) else: - return translator.ListMonad(translator, [ translator.ConstMonad.new(translator, item) for item in value ]) + return ListMonad(translator, [ ConstMonad.new(translator, item) for item in value ]) def postEllipsis(translator, node): - return translator.ConstMonad.new(translator, Ellipsis) + return ConstMonad.new(translator, Ellipsis) def postList(translator, node): - return translator.ListMonad(translator, [ item.monad for item in node.nodes ]) + return ListMonad(translator, [ item.monad for item in node.nodes ]) def postTuple(translator, node): - return translator.ListMonad(translator, [ item.monad for item in node.nodes ]) + return ListMonad(translator, [ item.monad for item in node.nodes ]) def postName(translator, node): name = node.name t = translator @@ -703,7 +703,7 @@ def postName(translator, node): t = t.parent tableref = translator.subquery.get_tableref(name) if tableref is not None: - return translator.ObjectIterMonad(translator, tableref, tableref.entity) + return ObjectIterMonad(translator, tableref, tableref.entity) else: assert False, name # pragma: no cover def postAdd(translator, node): return node.left.monad + node.right.monad @@ -724,9 +724,9 @@ def postUnarySub(translator, node): def postGetattr(translator, node): return 
node.expr.monad.getattr(node.attrname) def postAnd(translator, node): - return translator.AndMonad([ subnode.monad for subnode in node.nodes ]) + return AndMonad([ subnode.monad for subnode in node.nodes ]) def postOr(translator, node): - return translator.OrMonad([ subnode.monad for subnode in node.nodes ]) + return OrMonad([ subnode.monad for subnode in node.nodes ]) def postBitor(translator, node): left, right = (subnode.monad for subnode in node.nodes) return left | right @@ -776,7 +776,7 @@ def preCallFunc(translator, node): inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ subtranslator = translator_cls(inner_expr, translator.filter_num, translator.extractors, translator.vartypes, translator) - return translator.QuerySetMonad(translator, subtranslator) + return QuerySetMonad(translator, subtranslator) def postCallFunc(translator, node): args = [] kwargs = {} @@ -796,7 +796,7 @@ def postSubscript(translator, node): if len(node.subs) > 1: for x in node.subs: if isinstance(x, ast.Sliceobj): throw(TypeError) - key = translator.ListMonad(translator, [ item.monad for item in node.subs ]) + key = ListMonad(translator, [ item.monad for item in node.subs ]) return node.expr.monad[key] sub = node.subs[0] if isinstance(sub, ast.Sliceobj): @@ -830,7 +830,7 @@ def postIfExp(translator, node): elif not translator.row_value_syntax: throw(NotImplementedError) else: then_sql, else_sql = [ 'ROW' ] + then_sql, [ 'ROW' ] + else_sql expr = [ 'CASE', None, [ [ test_sql, then_sql ] ], else_sql ] - result = translator.ExprMonad.new(translator, result_type, expr) + result = ExprMonad.new(translator, result_type, expr) result.aggregated = test_monad.aggregated or then_monad.aggregated or else_monad.aggregated return result @@ -1027,18 +1027,17 @@ def __init__(monad, translator, type): def mixin_init(monad): pass def cmp(monad, op, monad2): - return monad.translator.CmpMonad(op, monad, monad2) + return CmpMonad(op, monad, monad2) def 
contains(monad, item, not_in=False): throw(TypeError) def nonzero(monad): throw(TypeError) def negate(monad): - return monad.translator.NotMonad(monad) + return NotMonad(monad) def getattr(monad, attrname): try: property_method = getattr(monad, 'attr_' + attrname) except AttributeError: if not hasattr(monad, 'call_' + attrname): throw(AttributeError, '%r object has no attribute %r: {EXPR}' % (type2str(monad.type), attrname)) - translator = monad.translator - return translator.MethodMonad(monad, attrname) + return MethodMonad(monad, attrname) return property_method() def len(monad): throw(TypeError) def count(monad, distinct=None): @@ -1066,7 +1065,7 @@ def count(monad, distinct=None): '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. Got: {EXPR}' % translator.dialect) - result = translator.ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ]) + result = ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ]) result.aggregated = True return result def aggregate(monad, func_name, distinct=None, sep=None): @@ -1108,7 +1107,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): if func_name == 'GROUP_CONCAT': if sep is not None: aggr_ast.append(['VALUE', sep]) - result = translator.ExprMonad.new(translator, result_type, aggr_ast) + result = ExprMonad.new(translator, result_type, aggr_ast) result.aggregated = True return result def __call__(monad, *args, **kwargs): throw(TypeError) @@ -1153,7 +1152,7 @@ def contains(monad, item, not_in=False): '%s database provider does not support tuples. 
Got: {EXPR} ' % translator.dialect) op = 'NOT_IN' if not_in else 'IN' sql = [ op, expr, monad.getsql() ] - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql) def nonzero(monad): return monad def getsql(monad, subquery=None): provider = monad.translator.database.provider @@ -1263,7 +1262,7 @@ def contains(monad, x, not_in=False): sql = sqland([ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) else: sql = sqlor([ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql) def getsql(monad, subquery=None): return [ [ 'ROW' ] + [ item.getsql()[0] for item in monad.items ] ] @@ -1278,15 +1277,15 @@ class UuidMixin(MonadMixin): def make_numeric_binop(op, sqlop): def numeric_binop(monad, monad2): translator = monad.translator - if isinstance(monad2, (translator.AttrSetMonad, translator.NumericSetExprMonad)): - return translator.NumericSetExprMonad(op, sqlop, monad, monad2) + if isinstance(monad2, (AttrSetMonad, NumericSetExprMonad)): + return NumericSetExprMonad(op, sqlop, monad, monad2) if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) result_type, monad, monad2 = coerce_monads(monad, monad2) if result_type is None: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql()[0] right_sql = monad2.getsql()[0] - return translator.NumericExprMonad(translator, result_type, [ sqlop, left_sql, right_sql ]) + return NumericExprMonad(translator, result_type, [ sqlop, left_sql, right_sql ]) numeric_binop.__name__ = sqlop return numeric_binop @@ -1301,36 +1300,36 @@ def mixin_init(monad): __mod__ = make_numeric_binop('%', 'MOD') def __pow__(monad, monad2): translator = monad.translator - if not isinstance(monad2, translator.NumericMixin): + if not isinstance(monad2, NumericMixin): throw(TypeError, _binop_errmsg % 
(type2str(monad.type), type2str(monad2.type), '**')) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return translator.NumericExprMonad(translator, float, [ 'POW', left_sql[0], right_sql[0] ]) + return NumericExprMonad(translator, float, [ 'POW', left_sql[0], right_sql[0] ]) def __neg__(monad): sql = monad.getsql()[0] translator = monad.translator - return translator.NumericExprMonad(translator, monad.type, [ 'NEG', sql ]) + return NumericExprMonad(translator, monad.type, [ 'NEG', sql ]) def abs(monad): sql = monad.getsql()[0] translator = monad.translator - return translator.NumericExprMonad(translator, monad.type, [ 'ABS', sql ]) + return NumericExprMonad(translator, monad.type, [ 'ABS', sql ]) def nonzero(monad): translator = monad.translator - return translator.CmpMonad('!=', monad, translator.ConstMonad.new(translator, 0)) + return CmpMonad('!=', monad, ConstMonad.new(translator, 0)) def negate(monad): translator = monad.translator - result = translator.CmpMonad('==', monad, translator.ConstMonad.new(translator, 0)) - if isinstance(monad, translator.AttrMonad) and not monad.attr.nullable: + result = CmpMonad('==', monad, ConstMonad.new(translator, 0)) + if isinstance(monad, AttrMonad) and not monad.attr.nullable: return result sql = [ 'OR', result.getsql()[0], [ 'IS_NULL', monad.getsql()[0] ] ] - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql) def numeric_attr_factory(name): def attr_func(monad): sql = [ name, monad.getsql()[0] ] translator = monad.translator - return translator.NumericExprMonad(translator, int, sql) + return NumericExprMonad(translator, int, sql) attr_func.__name__ = name.lower() return attr_func @@ -1339,7 +1338,7 @@ def datetime_binop(monad, monad2): translator = monad.translator if monad2.type != timedelta: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) - expr_monad_cls = translator.DateExprMonad if monad.type 
is date else translator.DatetimeExprMonad + expr_monad_cls = DateExprMonad if monad.type is date else DatetimeExprMonad delta = monad2.value if isinstance(monad2, TimedeltaConstMonad) else monad2.getsql()[0] return expr_monad_cls(translator, monad.type, [ sqlop, monad.getsql()[0], delta ]) datetime_binop.__name__ = sqlop @@ -1371,7 +1370,7 @@ def mixin_init(monad): def call_date(monad): translator = monad.translator sql = [ 'DATE', monad.getsql()[0] ] - return translator.ExprMonad.new(translator, date, sql) + return ExprMonad.new(translator, date, sql) attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') @@ -1387,7 +1386,7 @@ def string_binop(monad, monad2): left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return translator.StringExprMonad(translator, monad.type, [ sqlop, left_sql[0], right_sql[0] ]) + return StringExprMonad(translator, monad.type, [ sqlop, left_sql[0], right_sql[0] ]) string_binop.__name__ = sqlop return string_binop @@ -1396,7 +1395,7 @@ def func(monad): sql = monad.getsql() assert len(sql) == 1 translator = monad.translator - return translator.StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) + return StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) func.__name__ = sqlop return func @@ -1406,17 +1405,17 @@ def mixin_init(monad): __add__ = make_string_binop('+', 'CONCAT') def __getitem__(monad, index): translator = monad.translator - if isinstance(index, translator.ListMonad): throw(TypeError, "String index must be of 'int' type. Got 'tuple' in {EXPR}") + if isinstance(index, ListMonad): throw(TypeError, "String index must be of 'int' type. 
Got 'tuple' in {EXPR}") elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') start, stop = index.start, index.stop if start is None and stop is None: return monad - if isinstance(monad, translator.StringConstMonad) \ - and (start is None or isinstance(start, translator.NumericConstMonad)) \ - and (stop is None or isinstance(stop, translator.NumericConstMonad)): + if isinstance(monad, StringConstMonad) \ + and (start is None or isinstance(start, NumericConstMonad)) \ + and (stop is None or isinstance(stop, NumericConstMonad)): if start is not None: start = start.value if stop is not None: stop = stop.value - return translator.ConstMonad.new(translator, monad.value[start:stop]) + return ConstMonad.new(translator, monad.value[start:stop]) if start is not None and start.type is not int: throw(TypeError, "Invalid type of start index (expected 'int', got %r) in string slice {EXPR}" % type2str(start.type)) @@ -1424,9 +1423,9 @@ def __getitem__(monad, index): throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in string slice {EXPR}" % type2str(stop.type)) expr_sql = monad.getsql()[0] - if start is None: start = translator.ConstMonad.new(translator, 0) + if start is None: start = ConstMonad.new(translator, 0) - if isinstance(start, translator.NumericConstMonad): + if isinstance(start, NumericConstMonad): if start.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') start_sql = [ 'VALUE', start.value + 1 ] else: @@ -1435,28 +1434,28 @@ def __getitem__(monad, index): if stop is None: len_sql = None - elif isinstance(stop, translator.NumericConstMonad): + elif isinstance(stop, NumericConstMonad): if stop.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') - if isinstance(start, translator.NumericConstMonad): + if isinstance(start, NumericConstMonad): len_sql = [ 'VALUE', stop.value - start.value ] else: len_sql 
= [ 'SUB', [ 'VALUE', stop.value ], start.getsql()[0] ] else: stop_sql = stop.getsql()[0] - if isinstance(start, translator.NumericConstMonad): + if isinstance(start, NumericConstMonad): len_sql = [ 'SUB', stop_sql, [ 'VALUE', start.value ] ] else: len_sql = [ 'SUB', stop_sql, start.getsql()[0] ] sql = [ 'SUBSTR', expr_sql, start_sql, len_sql ] - return translator.StringExprMonad(translator, monad.type, sql) + return StringExprMonad(translator, monad.type, sql) - if isinstance(monad, translator.StringConstMonad) and isinstance(index, translator.NumericConstMonad): - return translator.ConstMonad.new(translator, monad.value[index.value]) + if isinstance(monad, StringConstMonad) and isinstance(index, NumericConstMonad): + return ConstMonad.new(translator, monad.value[index.value]) if index.type is not int: throw(TypeError, 'String indices must be integers. Got %r in expression {EXPR}' % type2str(index.type)) expr_sql = monad.getsql()[0] - if isinstance(index, translator.NumericConstMonad): + if isinstance(index, NumericConstMonad): value = index.value if value >= 0: value += 1 index_sql = [ 'VALUE', value ] @@ -1464,23 +1463,23 @@ def __getitem__(monad, index): inner_sql = index.getsql()[0] index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] - return translator.StringExprMonad(translator, monad.type, sql) + return StringExprMonad(translator, monad.type, sql) def negate(monad): sql = monad.getsql()[0] translator = monad.translator - result = translator.BoolExprMonad(translator, [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) + result = BoolExprMonad(translator, [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) result.aggregated = monad.aggregated return result def nonzero(monad): sql = monad.getsql()[0] translator = monad.translator - result = translator.BoolExprMonad(translator, [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) + result = 
BoolExprMonad(translator, [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) result.aggregated = monad.aggregated return result def len(monad): sql = monad.getsql()[0] translator = monad.translator - return translator.NumericExprMonad(translator, int, [ 'LENGTH', sql ]) + return NumericExprMonad(translator, int, [ 'LENGTH', sql ]) def contains(monad, item, not_in=False): check_comparable(item, monad, 'LIKE') return monad._like(item, before='%', after='%', not_like=not_in) @@ -1501,7 +1500,7 @@ def call_endswith(monad, arg): def _like(monad, item, before=None, after=None, not_like=False): escape = False translator = monad.translator - if isinstance(item, translator.StringConstMonad): + if isinstance(item, StringConstMonad): value = item.value if '%' in value or '_' in value: escape = True @@ -1520,7 +1519,7 @@ def _like(monad, item, before=None, after=None, not_like=False): elif after: item_sql = [ 'CONCAT', item_sql, [ 'VALUE', after ] ] sql = [ 'NOT_LIKE' if not_like else 'LIKE', monad.getsql()[0], item_sql ] if escape: sql.append([ 'VALUE', '!' 
]) - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql) def strip(monad, chars, strip_type): translator = monad.translator if chars is not None and not are_comparable_types(monad.type, chars.type, None): @@ -1530,7 +1529,7 @@ def strip(monad, chars, strip_type): parent_sql = monad.getsql()[0] sql = [ strip_type, parent_sql ] if chars is not None: sql.append(chars.getsql()[0]) - return translator.StringExprMonad(translator, monad.type, sql) + return StringExprMonad(translator, monad.type, sql) def call_strip(monad, chars=None): return monad.strip(chars, 'TRIM') def call_lstrip(monad, chars=None): @@ -1547,7 +1546,7 @@ def mixin_init(monad): def get_path(monad): return monad, [] def __getitem__(monad, key): - return monad.translator.JsonItemMonad(monad, key) + return JsonItemMonad(monad, key) def contains(monad, key, not_in=False): translator = monad.translator if isinstance(key, ParamMonad): @@ -1560,35 +1559,35 @@ def contains(monad, key, not_in=False): key_sql = key.getsql()[0] sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] if not_in: sql = [ 'NOT', sql ] - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql) def __or__(monad, other): translator = monad.translator - if not isinstance(other, translator.JsonMixin): + if not isinstance(other, JsonMixin): raise TypeError('Should be JSON: %s' % ast2src(other.node)) left_sql = monad.getsql()[0] right_sql = other.getsql()[0] sql = [ 'JSON_CONCAT', left_sql, right_sql ] - return translator.JsonExprMonad(translator, Json, sql) + return JsonExprMonad(translator, Json, sql) def len(monad): translator = monad.translator sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] - return translator.NumericExprMonad(translator, int, sql) + return NumericExprMonad(translator, int, sql) def cast_from_json(monad, type): if type in (Json, NoneType): return monad throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') def 
nonzero(monad): translator = monad.translator - return translator.BoolExprMonad(translator, [ 'JSON_NONZERO', monad.getsql()[0] ]) + return BoolExprMonad(translator, [ 'JSON_NONZERO', monad.getsql()[0] ]) class ObjectMixin(MonadMixin): def mixin_init(monad): assert isinstance(monad.type, EntityMeta) def negate(monad): translator = monad.translator - return translator.CmpMonad('is', monad, translator.NoneMonad(translator)) + return CmpMonad('is', monad, NoneMonad(translator)) def nonzero(monad): translator = monad.translator - return translator.CmpMonad('is not', monad, translator.NoneMonad(translator)) + return CmpMonad('is not', monad, NoneMonad(translator)) def getattr(monad, name): translator = monad.translator entity = monad.type @@ -1597,9 +1596,9 @@ def getattr(monad, name): 'Entity %s does not have attribute %s: {EXPR}' % (entity.__name__, name)) if hasattr(monad, 'tableref'): monad.tableref.used_attrs.add(attr) if not attr.is_collection: - return translator.AttrMonad.new(monad, attr) + return AttrMonad.new(monad, attr) else: - return translator.AttrSetMonad(monad, attr) + return AttrSetMonad(monad, attr) def requires_distinct(monad, joined=False): return monad.attr.reverse.is_collection or monad.parent.requires_distinct(joined) # parent ??? 
@@ -1619,16 +1618,16 @@ class AttrMonad(Monad): def new(parent, attr, *args, **kwargs): translator = parent.translator type = normalize_type(attr.py_type) - if type in numeric_types: cls = translator.NumericAttrMonad - elif type is unicode: cls = translator.StringAttrMonad - elif type is date: cls = translator.DateAttrMonad - elif type is time: cls = translator.TimeAttrMonad - elif type is timedelta: cls = translator.TimedeltaAttrMonad - elif type is datetime: cls = translator.DatetimeAttrMonad - elif type is buffer: cls = translator.BufferAttrMonad - elif type is UUID: cls = translator.UuidAttrMonad - elif type is Json: cls = translator.JsonAttrMonad - elif isinstance(type, EntityMeta): cls = translator.ObjectAttrMonad + if type in numeric_types: cls = NumericAttrMonad + elif type is unicode: cls = StringAttrMonad + elif type is date: cls = DateAttrMonad + elif type is time: cls = TimeAttrMonad + elif type is timedelta: cls = TimedeltaAttrMonad + elif type is datetime: cls = DatetimeAttrMonad + elif type is buffer: cls = BufferAttrMonad + elif type is UUID: cls = UuidAttrMonad + elif type is Json: cls = JsonAttrMonad + elif isinstance(type, EntityMeta): cls = ObjectAttrMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(parent, attr, *args, **kwargs) def __new__(cls, *args): @@ -1680,14 +1679,14 @@ def negate(monad): result_sql = [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ] ] if monad.attr.nullable: result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] - result = translator.BoolExprMonad(translator, result_sql) + result = BoolExprMonad(translator, result_sql) result.aggregated = monad.aggregated return result def nonzero(monad): sql = monad.getsql()[0] translator = monad.translator result_sql = [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ] ] - result = translator.BoolExprMonad(translator, result_sql) + result = BoolExprMonad(translator, result_sql) result.aggregated = monad.aggregated return result @@ -1704,16 +1703,16 @@ class ParamMonad(Monad): 
@staticmethod def new(translator, type, paramkey): type = normalize_type(type) - if type in numeric_types: cls = translator.NumericParamMonad - elif type is unicode: cls = translator.StringParamMonad - elif type is date: cls = translator.DateParamMonad - elif type is time: cls = translator.TimeParamMonad - elif type is timedelta: cls = translator.TimedeltaParamMonad - elif type is datetime: cls = translator.DatetimeParamMonad - elif type is buffer: cls = translator.BufferParamMonad - elif type is UUID: cls = translator.UuidParamMonad - elif type is Json: cls = translator.JsonParamMonad - elif isinstance(type, EntityMeta): cls = translator.ObjectParamMonad + if type in numeric_types: cls = NumericParamMonad + elif type is unicode: cls = StringParamMonad + elif type is date: cls = DateParamMonad + elif type is time: cls = TimeParamMonad + elif type is timedelta: cls = TimedeltaParamMonad + elif type is datetime: cls = DatetimeParamMonad + elif type is buffer: cls = BufferParamMonad + elif type is UUID: cls = UuidParamMonad + elif type is Json: cls = JsonParamMonad + elif isinstance(type, EntityMeta): cls = ObjectParamMonad else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type,)) result = cls(translator, type, paramkey) result.aggregated = False @@ -1762,14 +1761,14 @@ def getsql(monad, subquery=None): class ExprMonad(Monad): @staticmethod def new(translator, type, sql): - if type in numeric_types: cls = translator.NumericExprMonad - elif type is unicode: cls = translator.StringExprMonad - elif type is date: cls = translator.DateExprMonad - elif type is time: cls = translator.TimeExprMonad - elif type is timedelta: cls = translator.TimedeltaExprMonad - elif type is datetime: cls = translator.DatetimeExprMonad - elif type is Json: cls = translator.JsonExprMonad - elif isinstance(type, EntityMeta): cls = translator.ObjectExprMonad + if type in numeric_types: cls = NumericExprMonad + elif type is unicode: cls = StringExprMonad + elif type is 
date: cls = DateExprMonad + elif type is time: cls = TimeExprMonad + elif type is timedelta: cls = TimedeltaExprMonad + elif type is datetime: cls = DatetimeExprMonad + elif type is Json: cls = JsonExprMonad + elif isinstance(type, EntityMeta): cls = ObjectExprMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(translator, type, sql) def __new__(cls, *args): @@ -1824,7 +1823,7 @@ def cast_from_json(monad, type): return monad base_monad, path = monad.get_path() sql = [ 'JSON_VALUE', base_monad.getsql()[0], path, type ] - return translator.ExprMonad.new(translator, Json if type is NoneType else type, sql) + return ExprMonad.new(translator, Json if type is NoneType else type, sql) def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] @@ -1837,16 +1836,16 @@ class ConstMonad(Monad): @staticmethod def new(translator, value): value_type, value = normalize(value) - if value_type in numeric_types: cls = translator.NumericConstMonad - elif value_type is unicode: cls = translator.StringConstMonad - elif value_type is date: cls = translator.DateConstMonad - elif value_type is time: cls = translator.TimeConstMonad - elif value_type is timedelta: cls = translator.TimedeltaConstMonad - elif value_type is datetime: cls = translator.DatetimeConstMonad - elif value_type is NoneType: cls = translator.NoneMonad - elif value_type is buffer: cls = translator.BufferConstMonad - elif value_type is Json: cls = translator.JsonConstMonad - elif issubclass(value_type, type(Ellipsis)): cls = translator.EllipsisMonad + if value_type in numeric_types: cls = NumericConstMonad + elif value_type is unicode: cls = StringConstMonad + elif value_type is date: cls = DateConstMonad + elif value_type is time: cls = TimeConstMonad + elif value_type is timedelta: cls = TimedeltaConstMonad + elif value_type is datetime: cls = DatetimeConstMonad + elif value_type is NoneType: cls = NoneMonad + elif value_type is buffer: cls = BufferConstMonad + elif 
value_type is Json: cls = JsonConstMonad + elif issubclass(value_type, type(Ellipsis)): cls = EllipsisMonad else: throw(NotImplementedError, value_type) # pragma: no cover result = cls(translator, value) result.aggregated = False @@ -1872,7 +1871,7 @@ class EllipsisMonad(ConstMonad): class StringConstMonad(StringMixin, ConstMonad): def len(monad): - return monad.translator.ConstMonad.new(monad.translator, len(monad.value)) + return ConstMonad.new(monad.translator, len(monad.value)) class JsonConstMonad(JsonMixin, ConstMonad): pass class BufferConstMonad(BufferMixin, ConstMonad): pass @@ -1905,8 +1904,8 @@ def negate(monad): elif negated_op == 'NOT': assert len(sql) == 2 negated_sql = sql[1] - else: return translator.NotMonad(translator, sql) - return translator.BoolExprMonad(translator, negated_sql) + else: return NotMonad(translator, sql) + return BoolExprMonad(translator, negated_sql) cmp_ops = { '>=' : 'GE', '>' : 'GT', '<=' : 'LE', '<' : 'LT' } @@ -1941,7 +1940,7 @@ def __init__(monad, op, left, right): monad.left = left monad.right = right def negate(monad): - return monad.translator.CmpMonad(cmp_negate[monad.op], monad.left, monad.right) + return CmpMonad(cmp_negate[monad.op], monad.left, monad.right) def getsql(monad, subquery=None): op = monad.op left_sql = monad.left.getsql() @@ -1980,7 +1979,7 @@ def __init__(monad, operands): monad.translator = translator for operand in operands: if operand.type is not bool: items.append(operand.nonzero()) - elif isinstance(operand, translator.LogicalBinOpMonad) and monad.binop == operand.binop: + elif isinstance(operand, LogicalBinOpMonad) and monad.binop == operand.binop: items.extend(operand.operands) else: items.append(operand) BoolMonad.__init__(monad, items[0].translator) @@ -2030,9 +2029,9 @@ class FuncMonad(with_metaclass(FuncMonadMeta, Monad)): def __call__(monad, *args, **kwargs): translator = monad.translator for arg in args: - assert isinstance(arg, translator.Monad) + assert isinstance(arg, Monad) for value 
in kwargs.values(): - assert isinstance(value, translator.Monad) + assert isinstance(value, Monad) try: return monad.call(*args, **kwargs) except TypeError as exc: reraise_improved_typeerror(exc, 'call', monad.type.__name__) @@ -2041,23 +2040,23 @@ class FuncBufferMonad(FuncMonad): func = buffer def call(monad, source, encoding=None, errors=None): translator = monad.translator - if not isinstance(source, translator.StringConstMonad): throw(TypeError) + if not isinstance(source, StringConstMonad): throw(TypeError) source = source.value if encoding is not None: - if not isinstance(encoding, translator.StringConstMonad): throw(TypeError) + if not isinstance(encoding, StringConstMonad): throw(TypeError) encoding = encoding.value if errors is not None: - if not isinstance(errors, translator.StringConstMonad): throw(TypeError) + if not isinstance(errors, StringConstMonad): throw(TypeError) errors = errors.value if PY2: if encoding and errors: source = source.encode(encoding, errors) elif encoding: source = source.encode(encoding) - return translator.ConstMonad.new(translator, buffer(source)) + return ConstMonad.new(translator, buffer(source)) else: if encoding and errors: value = buffer(source, encoding, errors) elif encoding: value = buffer(source, encoding) else: value = buffer(source) - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(translator, value) class FuncBoolMonad(FuncMonad): func = bool @@ -2078,32 +2077,32 @@ class FuncDecimalMonad(FuncMonad): func = Decimal def call(monad, x): translator = monad.translator - if not isinstance(x, translator.StringConstMonad): throw(TypeError) - return translator.ConstMonad.new(translator, Decimal(x.value)) + if not isinstance(x, StringConstMonad): throw(TypeError) + return ConstMonad.new(translator, Decimal(x.value)) class FuncDateMonad(FuncMonad): func = date def call(monad, year, month, day): translator = monad.translator for arg, name in izip((year, month, day), ('year', 'month', 'day')): - if 
not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of date(year, month, day) function must be of 'int' type. " "Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) - return translator.ConstMonad.new(translator, date(year.value, month.value, day.value)) + return ConstMonad.new(translator, date(year.value, month.value, day.value)) def call_today(monad): translator = monad.translator - return translator.DateExprMonad(translator, date, [ 'TODAY' ]) + return DateExprMonad(translator, date, [ 'TODAY' ]) class FuncTimeMonad(FuncMonad): func = time def call(monad, *args): translator = monad.translator for arg, name in izip(args, ('hour', 'minute', 'second', 'microsecond')): - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of time(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) - return translator.ConstMonad.new(translator, time(*tuple(arg.value for arg in args))) + return ConstMonad.new(translator, time(*tuple(arg.value for arg in args))) class FuncTimedeltaMonad(FuncMonad): func = timedelta @@ -2112,11 +2111,11 @@ def call(monad, days=None, seconds=None, microseconds=None, milliseconds=None, m args = days, seconds, microseconds, milliseconds, minutes, hours, weeks for arg, name in izip(args, ('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks')): if arg is None: continue - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of timedelta(...) function must be of 'int' type. 
Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = timedelta(*(arg.value if arg is not None else 0 for arg in args)) - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(translator, value) class FuncDatetimeMonad(FuncDateMonad): func = datetime @@ -2125,14 +2124,14 @@ def call(monad, year, month, day, hour=None, minute=None, second=None, microseco translator = monad.translator for arg, name in izip(args, ('year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond')): if arg is None: continue - if not isinstance(arg, translator.NumericMixin) or arg.type is not int: throw(TypeError, + if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of datetime(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = datetime(*(arg.value if arg is not None else 0 for arg in args)) - return translator.ConstMonad.new(translator, value) + return ConstMonad.new(translator, value) def call_now(monad): translator = monad.translator - return translator.DatetimeExprMonad(translator, datetime, [ 'NOW' ]) + return DatetimeExprMonad(translator, datetime, [ 'NOW' ]) class FuncBetweenMonad(FuncMonad): func = between @@ -2143,7 +2142,7 @@ def call(monad, x, a, b): '%s instance cannot be argument of between() function: {EXPR}' % x.type.__name__) translator = x.translator sql = [ 'BETWEEN', x.getsql()[0], a.getsql()[0], b.getsql()[0] ] - return translator.BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql) class FuncConcatMonad(FuncMonad): func = concat @@ -2156,7 +2155,7 @@ def call(monad, *args): if isinstance(t, EntityMeta) or type(t) in (tuple, SetType): throw(TranslationError, 'Invalid argument of concat() function: %s' % ast2src(arg.node)) result_ast.extend(arg.getsql()) - return translator.ExprMonad.new(translator, unicode, result_ast) + return 
ExprMonad.new(translator, unicode, result_ast) class FuncLenMonad(FuncMonad): func = len @@ -2173,9 +2172,9 @@ class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count def call(monad, x=None, distinct=None): translator = monad.translator - if isinstance(x, translator.StringConstMonad) and x.value == '*': x = None + if isinstance(x, StringConstMonad) and x.value == '*': x = None if x is not None: return x.count(distinct) - result = translator.ExprMonad.new(translator, int, [ 'COUNT', None ]) + result = ExprMonad.new(translator, int, [ 'COUNT', None ]) result.aggregated = True return result @@ -2217,7 +2216,7 @@ def call(monad, *args): result[i].append(sql) sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] if not isinstance(t, EntityMeta): sql = sql[0] - return translator.ExprMonad.new(translator, t, sql) + return ExprMonad.new(translator, t, sql) class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct @@ -2263,20 +2262,20 @@ def minmax(monad, sqlop, *args): if arg.type is bool: args[i] = NumericExprMonad(translator, int, [ 'TO_INT', arg.getsql() ]) sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ] - return translator.ExprMonad.new(translator, t, sql) + return ExprMonad.new(translator, t, sql) class FuncSelectMonad(FuncMonad): func = core.select def call(monad, queryset): translator = monad.translator - if not isinstance(queryset, translator.QuerySetMonad): throw(TypeError, + if not isinstance(queryset, QuerySetMonad): throw(TypeError, "'select' function expects generator expression, got: {EXPR}") return queryset class FuncExistsMonad(FuncMonad): func = core.exists def call(monad, arg): - if not isinstance(arg, monad.translator.SetMixin): throw(TypeError, + if not isinstance(arg, SetMixin): throw(TypeError, "'exists' function expects generator expression or collection, got: {EXPR}") return arg.nonzero() @@ -2320,7 +2319,6 @@ def call_distinct(monad): def make_attrset_binop(op, sqlop): def 
attrset_binop(monad, monad2): - NumericSetExprMonad = monad.translator.NumericSetExprMonad return NumericSetExprMonad(op, sqlop, monad, monad2) return attrset_binop @@ -2357,7 +2355,7 @@ def contains(monad, item, not_in=False): else: conditions += [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item.getsql(), expr_list) ] sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS', from_ast, [ 'WHERE' ] + conditions ] - result = translator.BoolExprMonad(translator, sql_ast) + result = BoolExprMonad(translator, sql_ast) result.nogroup = True return result elif not not_in: @@ -2365,7 +2363,7 @@ def contains(monad, item, not_in=False): tableref = monad.make_tableref(translator.subquery) expr_list = monad.make_expr_list() expr_ast = sqland([ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(expr_list, item.getsql()) ]) - return translator.BoolExprMonad(translator, expr_ast) + return BoolExprMonad(translator, expr_ast) else: subquery = Subquery(translator.subquery) tableref = monad.make_tableref(subquery) @@ -2380,7 +2378,7 @@ def contains(monad, item, not_in=False): conditions.extend(subquery.conditions) from_ast[-1][-1] = sqland([ from_ast[-1][-1] ] + conditions) expr_ast = sqland([ [ 'IS_NULL', expr ] for expr in expr_list ]) - return translator.BoolExprMonad(translator, expr_ast) + return BoolExprMonad(translator, expr_ast) def getattr(monad, name): try: return Monad.getattr(monad, name) except AttributeError: pass @@ -2388,7 +2386,7 @@ def getattr(monad, name): if not isinstance(entity, EntityMeta): throw(AttributeError) attr = entity._adict_.get(name) if attr is None: throw(AttributeError) - return monad.translator.AttrSetMonad(monad, attr) + return AttrSetMonad(monad, attr) def requires_distinct(monad, joined=False, for_count=False): if monad.parent.requires_distinct(joined): return True reverse = monad.attr.reverse @@ -2396,7 +2394,7 @@ def requires_distinct(monad, joined=False, for_count=False): if reverse.is_collection: translator = monad.translator if not for_count and 
not translator.hint_join: return True - if isinstance(monad.parent, monad.translator.AttrSetMonad): return True + if isinstance(monad.parent, AttrSetMonad): return True return False def count(monad, distinct=None): translator = monad.translator @@ -2450,7 +2448,7 @@ def count(monad, distinct=None): else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr, extra_grouping) translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = translator.ExprMonad.new(translator, int, sql_ast) + result = ExprMonad.new(translator, int, sql_ast) if optimized: result.aggregated = True else: result.nogroup = True return result @@ -2494,7 +2492,7 @@ def make_aggr(expr_list): else: result_type = item_type translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = translator.ExprMonad.new(monad.translator, result_type, sql_ast) + result = ExprMonad.new(monad.translator, result_type, sql_ast) if optimized: result.aggregated = True else: result.nogroup = True return result @@ -2503,20 +2501,20 @@ def nonzero(monad): sql_ast = [ 'EXISTS', subquery.from_ast, [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] translator = monad.translator - return translator.BoolExprMonad(translator, sql_ast) + return BoolExprMonad(translator, sql_ast) def negate(monad): subquery = monad._subselect() sql_ast = [ 'NOT_EXISTS', subquery.from_ast, [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] translator = monad.translator - return translator.BoolExprMonad(translator, sql_ast) + return BoolExprMonad(translator, sql_ast) call_is_empty = negate def make_tableref(monad, subquery): parent = monad.parent attr = monad.attr translator = monad.translator if isinstance(parent, ObjectMixin): parent_tableref = parent.tableref - elif isinstance(parent, translator.AttrSetMonad): parent_tableref = parent.make_tableref(subquery) + elif isinstance(parent, AttrSetMonad): parent_tableref = parent.make_tableref(subquery) else: assert False # 
pragma: no cover if attr.reverse: name_path = parent_tableref.name_path + '-' + attr.name @@ -2638,7 +2636,6 @@ def getsql(monad, subquery=None): def make_numericset_binop(op, sqlop): def numericset_binop(monad, monad2): - NumericSetExprMonad = monad.translator.NumericSetExprMonad return NumericSetExprMonad(op, sqlop, monad, monad2) return numericset_binop @@ -2676,7 +2673,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], subquery.from_ast, [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] - result = translator.ExprMonad.new(translator, result_type, sql_ast) + result = ExprMonad.new(translator, result_type, sql_ast) result.nogroup = True else: if not translator.from_optimized: @@ -2685,7 +2682,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): translator.subquery.from_ast.extend(from_ast) translator.from_optimized = True sql_ast = aggr_ast - result = translator.ExprMonad.new(translator, result_type, sql_ast) + result = ExprMonad.new(translator, result_type, sql_ast) result.aggregated = True return result def getsql(monad, subquery=None): @@ -2722,7 +2719,7 @@ def requires_distinct(monad, joined=False): def contains(monad, item, not_in=False): translator = monad.translator check_comparable(item, monad, 'in') - if isinstance(item, translator.ListMonad): + if isinstance(item, ListMonad): item_columns = [] for subitem in item.items: item_columns.extend(subitem.getsql()) else: item_columns = item.getsql() @@ -2779,17 +2776,17 @@ def contains(monad, item, not_in=False): having_ast = find_or_create_having_ast(subquery_ast) having_ast += in_conditions sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] - return translator.BoolExprMonad(translator, sql_ast) + return BoolExprMonad(translator, sql_ast) def nonzero(monad): subquery_ast = monad.subtranslator.shallow_copy_of_subquery_ast() subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] translator = monad.translator - return 
translator.BoolExprMonad(translator, subquery_ast) + return BoolExprMonad(translator, subquery_ast) def negate(monad): sql = monad.nonzero().sql assert sql[0] == 'EXISTS' translator = monad.translator - return translator.BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:]) + return BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:]) def count(monad, distinct=None): distinct = distinct_from_monad(distinct) translator = monad.translator @@ -2831,7 +2828,7 @@ def count(monad, distinct=None): else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] - return translator.ExprMonad.new(translator, int, sql_ast) + return ExprMonad.new(translator, int, sql_ast) len = count def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) @@ -2866,7 +2863,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): result_type = unicode else: result_type = expr_type - return translator.ExprMonad.new(translator, result_type, sql_ast) + return ExprMonad.new(translator, result_type, sql_ast) def call_count(monad, distinct=None): return monad.count(distinct=distinct) def call_sum(monad, distinct=None): @@ -2894,8 +2891,3 @@ def find_or_create_having_ast(subquery_ast): having_ast = [ 'HAVING' ] subquery_ast.insert(groupby_offset + 1, having_ast) return having_ast - -for name, value in items_list(globals()): - if name.endswith('Monad') or name.endswith('Mixin'): - setattr(SQLTranslator, name, value) -del name, value From 1c187d42769554bdf850b805f40e7af5d3180c4d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Feb 2018 12:48:05 +0300 Subject: [PATCH 322/547] New implementation of getattr support --- pony/orm/asttranslation.py | 38 +------------------ pony/orm/core.py | 67 +++++++++++++++++++++++----------- pony/orm/sqltranslation.py | 33 ++++++++++++++--- pony/orm/tests/test_getattr.py | 8 ++-- 4 
files changed, 79 insertions(+), 67 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 198cb5ca3..1e0d4a289 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -223,7 +223,6 @@ class PreTranslator(ASTTranslator): def __init__(translator, tree, globals, locals, special_functions, const_functions, outer_names=()): ASTTranslator.__init__(translator, tree) - translator.getattr_nodes = set() translator.globals = globals translator.locals = locals translator.special_functions = special_functions @@ -303,7 +302,6 @@ def postCallFunc(translator, node): elif x is getattr: attr_node = node.args[1] attr_node.parent_node = node - translator.getattr_nodes.add(attr_node) else: node.external = False elif x in translator.const_functions: for arg in node.args: @@ -312,23 +310,10 @@ def postCallFunc(translator, node): if node.dstar_args is not None and not node.dstar_args.constant: return node.constant = True -getattr_cache = {} extractors_cache = {} def create_extractors(code_key, tree, globals, locals, special_functions, const_functions, outer_names=()): - result = None - getattr_extractors = getattr_cache.get(code_key) - if getattr_extractors: - getattr_attrnames = HashableDict({src: extractor(globals, locals) - for src, extractor in iteritems(getattr_extractors)}) - extractors_key = HashableDict(code_key=code_key, getattr_attrnames=getattr_attrnames) - try: - result = extractors_cache.get(extractors_key) - except TypeError: - pass # unhashable type - if not result: - tree = copy_ast(tree) - + result = extractors_cache.get(code_key) if not result: pretranslator = PreTranslator(tree, globals, locals, special_functions, const_functions, outer_names) extractors = {} @@ -342,24 +327,5 @@ def extractor(globals, locals): def extractor(globals, locals, code=code): return eval(code, globals, locals) extractors[src] = extractor - - getattr_extractors = {} - getattr_attrnames = HashableDict() - for node in 
pretranslator.getattr_nodes: - if node in pretranslator.externals: - src = node.src - extractor = extractors[src] - getattr_extractors[src] = extractor - attrname_value = extractor(globals, locals) - getattr_attrnames[src] = attrname_value - elif isinstance(node, ast.Const): - attrname_value = node.value - else: throw(TypeError, '`%s` should be either external expression or constant.' % ast2src(node)) - if not isinstance(attrname_value, basestring): throw(TypeError, - '%s: attribute name must be string. Got: %r' % (ast2src(node.parent_node), attrname_value)) - node._attrname_value = attrname_value - getattr_cache[code_key] = getattr_extractors - - extractors_key = HashableDict(code_key=code_key, getattr_attrnames=getattr_attrnames) - result = extractors_cache[extractors_key] = extractors, tree, extractors_key + result = extractors_cache[code_key] = tree, extractors return result diff --git a/pony/orm/core.py b/pony/orm/core.py index b5ca8d9fd..9342b2e14 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5282,8 +5282,7 @@ def unpickle_query(query_result): class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) - extractors, tree, extractors_key = create_extractors( - code_key, tree, globals, locals, special_functions, const_functions) + tree, extractors = create_extractors(code_key, tree, globals, locals, special_functions, const_functions) filter_num = next(filter_num_counter) vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) @@ -5300,23 +5299,26 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False database.provider.normalize_vars(vars, vartypes) query._vars = vars - query._key = HashableDict(extractors_key, vartypes=vartypes, left_join=left_join, filters=()) + query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) query._database = database - translator = 
database._translator_cache.get(query._key) + translator = query._get_translator(query._key, vars) + if translator is None: pickled_tree = pickle_ast(tree) tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls - translator = translator_cls(tree_copy, filter_num, extractors, vartypes, left_join=left_join) + translator = translator_cls(tree_copy, filter_num, extractors, vars, vartypes, left_join=left_join) name_path = translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) - try: translator = translator_cls(tree_copy, filter_num, extractors, vartypes, + try: translator = translator_cls(tree_copy, filter_num, extractors, vars, vartypes, left_join=True, optimize=name_path) except OptimizationFailed: translator.optimization_failed = True translator.pickled_tree = pickled_tree - database._translator_cache[query._key] = translator + if translator.can_be_cached: + database._translator_cache[query._key] = translator + query._translator = translator query._filters = () query._next_kwarg_id = 0 @@ -5332,6 +5334,17 @@ def _clone(query, **kwargs): return new_query def __reduce__(query): return unpickle_query, (query._fetch(),) + def _get_translator(query, query_key, vars): + database = query._database + translator = database._translator_cache.get(query_key) + all_func_vartypes = {} + if translator is not None: + for key, attrname in iteritems(translator.getattr_values): + assert key in vars + if attrname != vars[key]: + del database._translator_cache[query_key] + return None + return translator def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type @@ -5339,8 +5352,18 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_fu attrs_to_prefetch = tuple(sorted(query._attrs_to_prefetch_dict.get(expr_type, ()))) else: 
attrs_to_prefetch = () - sql_key = (query._key, range, query._distinct, (aggr_func_name, aggr_func_distinct, sep), - query._for_update, query._nowait, options.INNER_JOIN_SYNTAX, attrs_to_prefetch) + sql_key = HashableDict( + query._key, + vartypes=HashableDict(query._translator.vartypes), + getattr_values=HashableDict(translator.getattr_values), + range=range, + distinct=query._distinct, + aggr_func=(aggr_func_name, aggr_func_distinct, sep), + for_update=query._for_update, + nowait=query._nowait, + inner_join_syntax=options.INNER_JOIN_SYNTAX, + attrs_to_prefetch=attrs_to_prefetch + ) database = query._database cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: @@ -5357,7 +5380,7 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_fu arguments_key = HashableDict(arguments) if type(arguments) is dict else arguments try: hash(arguments_key) except: query_key = None # arguments are unhashable - else: query_key = sql_key + (arguments_key,) + else: query_key = HashableDict(sql_key, arguments_key=arguments_key) else: query_key = None return sql, arguments, attr_offsets, query_key def get_sql(query): @@ -5549,7 +5572,8 @@ def _order_by(query, method_name, *args): tup = (('without_order',),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup - new_translator = query._database._translator_cache.get(new_key) + + new_translator = query._get_translator(new_key, query._vars) if new_translator is None: new_translator = query._translator.without_order() query._database._translator_cache[new_key] = new_translator @@ -5574,7 +5598,8 @@ def _order_by(query, method_name, *args): tup = (('order_by_numbers' if numbers else 'order_by_attributes', args),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup - new_translator = query._database._translator_cache.get(new_key) + + new_translator = 
query._get_translator(new_key, query._vars) if new_translator is None: if numbers: new_translator = query._translator.order_by_numbers(args) else: new_translator = query._translator.order_by_attributes(args) @@ -5611,32 +5636,32 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names 'Expected: %d, got: %d' % (expr_count, len(argnames))) filter_num = next(filter_num_counter) - extractors, func_ast, extractors_key = create_extractors( - func_id, func_ast, globals, locals, special_functions, const_functions, - argnames or prev_translator.subquery) + func_ast, extractors = create_extractors( + func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) if extractors: vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) new_vars = query._vars.copy() new_vars.update(vars) else: new_vars, vartypes = query._vars, HashableDict() - tup = (('order_by' if order_by else 'where' if original_names else 'filter', extractors_key, vartypes),) + tup = (('order_by' if order_by else 'where' if original_names else 'filter', func_id, vartypes),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) - new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes),) - new_translator = query._database._translator_cache.get(new_key) + new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) + + new_translator = query._get_translator(new_key, new_vars) if new_translator is None: prev_optimized = prev_translator.optimize - new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) + new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, 
vartypes) if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) translator_cls = prev_translator.__class__ new_translator = translator_cls( - tree_copy, prev_translator.original_filter_num, prev_translator.extractors, prev_translator.vartypes, + tree_copy, prev_translator.original_filter_num, prev_translator.extractors, None, prev_translator.vartypes, left_join=True, optimize=name_path) new_translator = query._reapply_filters(new_translator) - new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes) + new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) query._database._translator_cache[new_key] = new_translator return query._clone(_vars=new_vars, _key=new_key, _filters=new_filters, _translator=new_translator) def _reapply_filters(query, translator): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 5bf8a0deb..4039a01db 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -163,14 +163,17 @@ def call(translator, method, node): else: throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: %s' % ast2src(node)) return monad - def __init__(translator, tree, filter_num, extractors, vartypes, parent_translator=None, left_join=False, optimize=None): + def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_translator=None, left_join=False, optimize=None): assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) + translator.can_be_cached = True translator.database = None translator.lambda_argnames = None translator.filter_num = translator.original_filter_num = filter_num translator.extractors = extractors + translator.vars = vars.copy() if vars is not None else None translator.vartypes = 
vartypes.copy() + translator.getattr_values = {} translator.parent = parent_translator translator.left_join = left_join translator.optimize = optimize @@ -362,6 +365,7 @@ def func(value, converter=converter): offset += 1 translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in translator.row_layout ] + translator.vars = None def shallow_copy_of_subquery_ast(translator, move_outer_conditions=True, is_not_null_checks=False): subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=False, is_not_null_checks=is_not_null_checks) assert attr_offsets is None @@ -616,11 +620,12 @@ def apply_kwfilters(translator, filterattrs, original_names=False): monads.append(CmpMonad('==', attr_monad, param_monad)) for m in monads: translator.conditions.extend(m.getsql()) return translator - def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_names, extractors, vartypes): + def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_names, extractors, vars, vartypes): translator = deepcopy(translator) func_ast = copy_ast(func_ast) # func_ast = deepcopy(func_ast) translator.filter_num = filter_num translator.extractors.update(extractors) + translator.vars = vars.copy() if vars is not None else None translator.vartypes.update(vartypes) translator.lambda_argnames = list(argnames) translator.original_names = original_names @@ -647,11 +652,12 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ for m in cond_monads: if not m.aggregated: translator.conditions.extend(m.getsql()) else: translator.having_conditions.extend(m.getsql()) + translator.vars = None return translator def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ - subtranslator = translator_cls(inner_tree, translator.filter_num, translator.extractors, translator.vartypes, translator) + subtranslator = translator_cls(inner_tree, translator.filter_num, 
translator.extractors, translator.vars, translator.vartypes, translator) return QuerySetMonad(translator, subtranslator) def postGenExprIf(translator, node): monad = node.test.monad @@ -775,7 +781,7 @@ def preCallFunc(translator, node): for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ - subtranslator = translator_cls(inner_expr, translator.filter_num, translator.extractors, translator.vartypes, translator) + subtranslator = translator_cls(inner_expr, translator.filter_num, translator.extractors, translator.vars, translator.vartypes, translator) return QuerySetMonad(translator, subtranslator) def postCallFunc(translator, node): args = [] @@ -2165,8 +2171,23 @@ def call(monad, x): class GetattrMonad(FuncMonad): func = getattr def call(monad, obj_monad, name_monad): - name = name_monad.node._attrname_value - return obj_monad.getattr(name) + if isinstance(name_monad, ConstMonad): + attrname = name_monad.value + elif isinstance(name_monad, ParamMonad): + translator = monad.translator + while translator.parent: + translator = translator.parent + key = name_monad.paramkey[0] + if key in translator.getattr_values: + attrname = translator.getattr_values[key] + else: + attrname = translator.vars[key] + translator.getattr_values[key] = attrname + else: throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' + 'because %s will be different for each row' % ast2src(name_monad.node)) + if not isinstance(attrname, basestring): + throw(TypeError, 'In `{EXPR}` second argument should be a string. 
Got: %r' % attrname) + return obj_monad.getattr(attrname) class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count diff --git a/pony/orm/tests/test_getattr.py b/pony/orm/tests/test_getattr.py index 629d50623..309e8dd6e 100644 --- a/pony/orm/tests/test_getattr.py +++ b/pony/orm/tests/test_getattr.py @@ -72,18 +72,18 @@ def test_not_instance_iter(self): val = select(getattr(x.name, 'startswith')('S') for x in self.db.Artist).first() self.assertTrue(val) + @raises_exception(TranslationError, 'Expression `getattr(x, x.name)` cannot be translated into SQL ' + 'because x.name will be different for each row') @db_session - @raises_exception(TypeError, '`x.name` should be either external expression or constant.') def test_not_external(self): select(getattr(x, x.name) for x in self.db.Artist) - @raises_exception(TypeError, 'getattr(x, 1): attribute name must be string. Got: 1') + @raises_exception(TypeError, 'In `getattr(x, 1)` second argument should be a string. Got: 1') @db_session def test_not_string(self): select(getattr(x, 1) for x in self.db.Artist) - - @raises_exception(TypeError, 'getattr(x, name): attribute name must be string. Got: 1') + @raises_exception(TypeError, 'In `getattr(x, name)` second argument should be a string. 
Got: 1') @db_session def test_not_string(self): name = 1 From d1033e278066c20278129c23ad846fa17aaebe1f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 16 Oct 2017 20:00:04 +0300 Subject: [PATCH 323/547] Hybrid methods, properties & attributes (the tests don't pass yet) --- pony/orm/core.py | 31 ++-- pony/orm/sqltranslation.py | 96 ++++++++++-- .../test_hybrid_methods_and_properties.py | 139 ++++++++++++++++++ 3 files changed, 245 insertions(+), 21 deletions(-) create mode 100644 pony/orm/tests/test_hybrid_methods_and_properties.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 9342b2e14..768587500 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5298,11 +5298,11 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) database.provider.normalize_vars(vars, vartypes) - query._vars = vars query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) query._database = database - translator = query._get_translator(query._key, vars) + translator, vars = query._get_translator(query._key, vars) + query._vars = vars if translator is None: pickled_tree = pickle_ast(tree) @@ -5335,16 +5335,27 @@ def _clone(query, **kwargs): def __reduce__(query): return unpickle_query, (query._fetch(),) def _get_translator(query, query_key, vars): + new_vars = vars.copy() database = query._database translator = database._translator_cache.get(query_key) all_func_vartypes = {} if translator is not None: + if translator.func_extractors_map: + for func, func_extractors in iteritems(translator.func_extractors_map): + func_filter_num = translator.filter_num, 'func', id(func) + func_vars, func_vartypes = extract_vars( + func_filter_num, func_extractors, func.__globals__, {}, func.__closure__) # todo closures + database.provider.normalize_vars(func_vars, func_vartypes) + new_vars.update(func_vars) + 
all_func_vartypes.update(func_vartypes) + if all_func_vartypes != translator.func_vartypes: + return None, vars.copy() for key, attrname in iteritems(translator.getattr_values): - assert key in vars - if attrname != vars[key]: + assert key in new_vars + if attrname != new_vars[key]: del database._translator_cache[query_key] - return None - return translator + return None, vars.copy() + return translator, new_vars def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type @@ -5573,7 +5584,7 @@ def _order_by(query, method_name, *args): new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup - new_translator = query._get_translator(new_key, query._vars) + new_translator, new_vars = query._get_translator(new_key, query._vars) if new_translator is None: new_translator = query._translator.without_order() query._database._translator_cache[new_key] = new_translator @@ -5599,7 +5610,7 @@ def _order_by(query, method_name, *args): new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + tup - new_translator = query._get_translator(new_key, query._vars) + new_translator, new_vars = query._get_translator(new_key, query._vars) if new_translator is None: if numbers: new_translator = query._translator.order_by_numbers(args) else: new_translator = query._translator.order_by_attributes(args) @@ -5648,7 +5659,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names new_key = HashableDict(query._key, filters=query._key['filters'] + tup) new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) - new_translator = query._get_translator(new_key, new_vars) + new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: prev_optimized = 
prev_translator.optimize new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) @@ -5729,7 +5740,7 @@ def _apply_kwargs(query, kwargs, original_names=False): new_filters = query._filters + tup new_vars = query._vars.copy() new_vars.update(value_dict) - new_translator = query._database._translator_cache.get(new_key) + new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: new_translator = translator.apply_kwfilters(filterattrs, original_names) query._database._translator_cache[new_key] = new_translator diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 4039a01db..59d468376 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, with_metaclass -import types, sys, re, itertools +import types, sys, re, itertools, inspect from decimal import Decimal from datetime import date, time, datetime, timedelta from random import random @@ -13,13 +13,15 @@ from pony import options, utils from pony.utils import is_ident, throw, reraise, copy_ast, between, concat, coalesce -from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError +from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError, create_extractors +from pony.orm.decompiling import decompile from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ normalize, normalize_type, coerce_types, are_comparable_types, \ Json from pony.orm import core -from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper +from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper, \ + special_functions, const_functions, extract_vars NoneType = type(None) @@ -171,9 
+173,12 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr translator.lambda_argnames = None translator.filter_num = translator.original_filter_num = filter_num translator.extractors = extractors + translator.method_argnames_mapping_stack = [] + translator.func_extractors_map = {} translator.vars = vars.copy() if vars is not None else None translator.vartypes = vartypes.copy() translator.getattr_values = {} + translator.func_vartypes = {} translator.parent = parent_translator translator.left_join = left_join translator.optimize = optimize @@ -235,9 +240,15 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr node_name = node.name attr_names.reverse() name_path = node_name - parent_tableref = subquery.get_tableref(node_name) - if parent_tableref is None: throw(TranslationError, "Name %r must be defined in query" % node_name) + + monad = translator.resolve_name(node_name) + if monad is None: + throw(TranslationError, "Name %r must be defined in query" % node_name) + if not isinstance(monad, ObjectIterMonad): + throw(NotImplementedError) + parent_tableref = monad.tableref parent_entity = parent_tableref.entity + last_index = len(attr_names) - 1 for j, attrname in enumerate(attr_names): attr = parent_entity._adict_.get(attrname) @@ -699,9 +710,15 @@ def postList(translator, node): def postTuple(translator, node): return ListMonad(translator, [ item.monad for item in node.nodes ]) def postName(translator, node): - name = node.name + monad = translator.resolve_name(node.name) + assert monad is not None + return monad + def resolve_name(translator, name): t = translator while t is not None: + stack = t.method_argnames_mapping_stack + if stack and name in stack[-1]: + return stack[-1][name] argnames = t.lambda_argnames if argnames is not None and not t.original_names and name in argnames: i = argnames.index(name) @@ -710,7 +727,7 @@ def postName(translator, node): tableref = 
translator.subquery.get_tableref(name) if tableref is not None: return ObjectIterMonad(translator, tableref, tableref.entity) - else: assert False, name # pragma: no cover + return None def postAdd(translator, node): return node.left.monad + node.right.monad def postSub(translator, node): @@ -1241,6 +1258,54 @@ def __pow__(monad, monad2): raise_forgot_parentheses(monad) def __neg__(monad): raise_forgot_parentheses(monad) def abs(monad): raise_forgot_parentheses(monad) +class HybridMethodMonad(MethodMonad): + def __init__(monad, parent, attrname, func): + MethodMonad.__init__(monad, parent, attrname) + monad.func = func + def __call__(monad, *args, **kwargs): + translator = monad.translator + name_mapping = inspect.getcallargs(monad.func, monad.parent, *args, **kwargs) + + func = monad.func + if PY2 and isinstance(func, types.UnboundMethodType): + func = func.im_func + func_id = id(func) + func_filter_num = translator.filter_num, 'func', id(func) + func_ast, external_names, cells = decompile(func) + + func_ast, func_extractors = create_extractors( + func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) + + t = translator + while t.parent is not None: + t = t.parent + if func not in t.func_extractors_map: + func_vars, func_vartypes = extract_vars(func_filter_num, func_extractors, func.__globals__, {}, cells) + translator.database.provider.normalize_vars(func_vars, func_vartypes) + if func.__closure__: + translator.can_be_cached = False + if func_extractors: + t.func_extractors_map[func] = func_extractors + t.func_vartypes.update(func_vartypes) + t.vartypes.update(func_vartypes) + t.vars.update(func_vars) + + stack = translator.method_argnames_mapping_stack + stack.append(name_mapping) + prev_filter_num = translator.filter_num + translator.filter_num = func_filter_num + func_ast = copy_ast(func_ast) + try: + translator.dispatch(func_ast) + except Exception as e: + if len(e.args) == 1 and isinstance(e.args[0], 
basestring): + msg = e.args[0] + ' (inside %s.%s)' % (monad.parent.type.__name__, monad.attrname) + e.args = (msg,) + raise + translator.filter_num = prev_filter_num + stack.pop() + return func_ast.monad + class EntityMonad(Monad): def __init__(monad, translator, entity): Monad.__init__(monad, translator, SetType(entity)) @@ -1594,12 +1659,21 @@ def negate(monad): def nonzero(monad): translator = monad.translator return CmpMonad('is not', monad, NoneMonad(translator)) - def getattr(monad, name): + def getattr(monad, attrname): translator = monad.translator entity = monad.type - attr = entity._adict_.get(name) or entity._subclass_adict_.get(name) - if attr is None: throw(AttributeError, - 'Entity %s does not have attribute %s: {EXPR}' % (entity.__name__, name)) + attr = entity._adict_.get(attrname) or entity._subclass_adict_.get(attrname) + if attr is None: + if hasattr(entity, attrname): + attr = getattr(entity, attrname, None) + if isinstance(attr, property): + new_monad = HybridMethodMonad(monad, attrname, attr.fget) + return new_monad() + if callable(attr): + func = getattr(attr, '__func__') if PY2 else attr + if func is not None: return HybridMethodMonad(monad, attrname, func) + throw(NotImplementedError, '{EXPR} cannot be translated to SQL') + throw(AttributeError, 'Entity %s does not have attribute %s: {EXPR}' % (entity.__name__, attrname)) if hasattr(monad, 'tableref'): monad.tableref.used_attrs.add(attr) if not attr.is_collection: return AttrMonad.new(monad, attr) diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py new file mode 100644 index 000000000..7b2b4256b --- /dev/null +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -0,0 +1,139 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:') + +class Person(db.Entity): + first_name = Required(str) + last_name = Required(str) + favorite_color = Optional(str) + 
cars = Set(lambda: Car) + + @property + def full_name(self): + return self.first_name + ' ' + self.last_name + + @property + def has_car(self): + return not self.cars.is_empty() + + def cars_by_color1(self, color): + return select(car for car in self.cars if car.color == color) + + def cars_by_color2(self, color): + return self.cars.select(lambda car: car.color == color) + + @property + def cars_price(self): + return sum(c.price for c in self.cars) + + @property + def incorrect_full_name(self): + return self.first_name + ' ' + p.last_name # p is FakePerson instance here + + +class FakePerson(object): + pass + +p = FakePerson() +p.last_name = '***' + + +class Car(db.Entity): + brand = Required(str) + model = Required(str) + owner = Optional(Person) + year = Required(int) + price = Required(int) + color = Required(str) + +db.generate_mapping(create_tables=True) + +with db_session: + p1 = Person(first_name='Alexander', last_name='Kozlovsky', favorite_color='white') + p2 = Person(first_name='Alexei', last_name='Malashkevich', favorite_color='green') + p3 = Person(first_name='Vitaliy', last_name='Abetkin') + p4 = Person(first_name='Alexander', last_name='Tischenko', favorite_color='blue') + + c1 = Car(brand='Peugeot', model='306', owner=p1, year=2006, price=14000, color='red') + c2 = Car(brand='Honda', model='Accord', owner=p1, year=2007, price=13850, color='white') + c3 = Car(brand='Nissan', model='Skyline', owner=p2, year=2008, price=29900, color='black') + c4 = Car(brand='Volkswagen', model='Passat', owner=p1, year=2012, price=9400, color='blue') + c5 = Car(brand='Koenigsegg', model='CCXR', owner=p4, year=2016, price=4850000, color='white') + c6 = Car(brand='Lada', model='Kalina', owner=p4, year=2015, price=5000, color='white') + + +class TestHybridsAndProperties(unittest.TestCase): + @db_session + def test1(self): + persons = select(p.full_name for p in Person if p.has_car)[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexei Malashkevich', 'Alexander 
Tischenko'}) + + @db_session + def test2(self): + cars_prices = select(p.cars_price for p in Person)[:] + self.assertEqual(set(cars_prices), {0, 29900, 37250, 4855000}) + + @db_session + def test3(self): + persons = select(p.full_name for p in Person if p.cars_price > 100000)[:] + self.assertEqual(set(persons), {'Alexander Tischenko'}) + + @db_session + def test4(self): + persons = select(p.full_name for p in Person if not p.cars_price)[:] + self.assertEqual(set(persons), {'Vitaliy Abetkin'}) + + @db_session + def test5(self): + persons = select(p.full_name for p in Person if exists(c for c in p.cars_by_color2('white') if c.price > 10000))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexander Tischenko'}) + + @db_session + def test6(self): + persons = select(p.full_name for p in Person if exists(c for c in p.cars_by_color1('white') if c.price > 10000))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky', 'Alexander Tischenko'}) + + @db_session + def test7(self): + c1 = Car[1] + persons = select(p.full_name for p in Person if c1 in p.cars_by_color2('red'))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test8(self): + c1 = Car[1] + persons = select(p.full_name for p in Person if c1 in p.cars_by_color1('red'))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test9(self): + persons = select(p.full_name for p in Person if p.cars_by_color1(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test10(self): + persons = select(p.full_name for p in Person if not p.cars_by_color1(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Tischenko', 'Alexei Malashkevich', 'Vitaliy Abetkin'}) + + @db_session + def test11(self): + persons = select(p.full_name for p in Person if p.cars_by_color2(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Kozlovsky'}) + + @db_session + def test12(self): + persons = 
select(p.full_name for p in Person if not p.cars_by_color2(p.favorite_color))[:] + self.assertEqual(set(persons), {'Alexander Tischenko', 'Alexei Malashkevich', 'Vitaliy Abetkin'}) + + @db_session + def test13(self): + # This test checks if accessing function-specific globals works correctly + persons = select(p.incorrect_full_name for p in Person if p.has_car)[:] + self.assertEqual(set(persons), {'Alexander ***', 'Alexei ***', 'Alexander ***'}) + + +if __name__ == '__main__': + unittest.main() From d21a9c177fc1dae727c38c666d802c12b24d7702 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 21 Mar 2018 16:00:34 +0300 Subject: [PATCH 324/547] Add support for calling obj.attr.select(lambda), obj.attr.exists(lambda) and obj.attr.filter(lambda) in queries --- pony/orm/sqltranslation.py | 47 ++++++++++++++----- .../tests/test_declarative_attr_set_monad.py | 28 +++++++++++ 2 files changed, 62 insertions(+), 13 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 59d468376..62192a23a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -210,9 +210,16 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr monad = getattr(node, 'monad', None) src = getattr(node, 'src', None) if monad: # Lambda was encountered inside generator - assert isinstance(monad, EntityMonad) + assert parent_translator and i == 0 entity = monad.type.item_type - tablerefs[name] = TableRef(subquery, name, entity) + if isinstance(monad, EntityMonad): + tablerefs[name] = TableRef(subquery, name, entity) + elif isinstance(monad, AttrSetMonad): + translator.subquery = monad._subselect(translator.subquery) + tableref = monad.tableref + translator.method_argnames_mapping_stack.append({ + name: ObjectIterMonad(translator, tableref, entity)}) + else: assert False # pragma: no cover elif src: iterable = translator.vartypes[translator.filter_num, src] if not isinstance(iterable, SetType): throw(TranslationError, @@ 
-287,7 +294,7 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr if isinstance(if_.monad, AndMonad): cond_monads = if_.monad.operands else: cond_monads = [ if_.monad ] for m in cond_monads: - if not m.aggregated: translator.conditions.extend(m.getsql()) + if not getattr(m, 'aggregated', False): translator.conditions.extend(m.getsql()) else: translator.having_conditions.extend(m.getsql()) translator.dispatch(tree.expr) @@ -300,13 +307,15 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr expr_type = monad.type if isinstance(expr_type, SetType): expr_type = expr_type.item_type if isinstance(expr_type, EntityMeta): - monad.orderby_columns = list(xrange(1, len(expr_type._pk_columns_)+1)) + entity = expr_type + translator.expr_type = entity + monad.orderby_columns = list(xrange(1, len(entity._pk_columns_)+1)) if monad.aggregated: throw(TranslationError) - if isinstance(monad, ObjectMixin): - entity = monad.type + if isinstance(monad, QuerySetMonad): + throw(NotImplementedError) + elif isinstance(monad, ObjectMixin): tableref = monad.tableref elif isinstance(monad, AttrSetMonad): - entity = monad.type.item_type tableref = monad.make_tableref(translator.subquery) else: assert False # pragma: no cover if translator.aggregated: @@ -317,7 +326,6 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr pk_only = parent_translator is not None or translator.aggregated alias, pk_columns = tableref.make_join(pk_only=pk_only) translator.alias = alias - translator.expr_type = entity translator.expr_columns = [ [ 'COLUMN', alias, column ] for column in pk_columns ] translator.row_layout = None translator.col_names = [ attr.name for attr in entity._attrs_ @@ -783,9 +791,10 @@ def preCallFunc(translator, node): method_monad = func_node.monad if not isinstance(method_monad, MethodMonad): throw(NotImplementedError) entity_monad = method_monad.parent - if not isinstance(entity_monad, 
EntityMonad): throw(NotImplementedError) + if not isinstance(entity_monad, (EntityMonad, AttrSetMonad)): throw(NotImplementedError) entity = entity_monad.type.item_type - if method_monad.attrname != 'select': throw(TypeError) + method_name = method_monad.attrname + if method_name not in ('select', 'filter', 'exists'): throw(TypeError) if len(lambda_expr.argnames) != 1: throw(TypeError) if lambda_expr.varargs: throw(TypeError) if lambda_expr.kwargs: throw(TypeError) @@ -799,7 +808,10 @@ def preCallFunc(translator, node): inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ subtranslator = translator_cls(inner_expr, translator.filter_num, translator.extractors, translator.vars, translator.vartypes, translator) - return QuerySetMonad(translator, subtranslator) + monad = QuerySetMonad(translator, subtranslator) + if method_name == 'exists': + monad = monad.nonzero() + return monad def postCallFunc(translator, node): args = [] kwargs = {} @@ -2482,6 +2494,12 @@ def getattr(monad, name): attr = entity._adict_.get(name) if attr is None: throw(AttributeError) return AttrSetMonad(monad, attr) + def call_select(monad): + # calling with lambda argument processed in preCallFunc + return monad + call_filter = call_select + def call_exists(monad): + return monad def requires_distinct(monad, joined=False, for_count=False): if monad.parent.requires_distinct(joined): return True reverse = monad.attr.reverse @@ -2704,11 +2722,12 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F expr_ast = [ 'COLUMN', alias, expr_name ] if coalesce_to_zero: expr_ast = [ 'COALESCE', expr_ast, [ 'VALUE', 0 ] ] return expr_ast, False - def _subselect(monad): + def _subselect(monad, subquery=None): if monad.subquery is not None: return monad.subquery attr = monad.attr translator = monad.translator - subquery = Subquery(translator.subquery) + if subquery is None: + subquery = Subquery(translator.subquery) 
monad.make_tableref(subquery) subquery.expr_list = monad.make_expr_list() if not attr.reverse and not attr.is_required: @@ -2974,6 +2993,8 @@ def call_group_concat(monad, sep=None, distinct=None): if not isinstance(sep, basestring): throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__) return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) + def getsql(monad): + throw(NotImplementedError) def find_or_create_having_ast(subquery_ast): groupby_offset = None diff --git a/pony/orm/tests/test_declarative_attr_set_monad.py b/pony/orm/tests/test_declarative_attr_set_monad.py index 938f882d2..6217277f1 100644 --- a/pony/orm/tests/test_declarative_attr_set_monad.py +++ b/pony/orm/tests/test_declarative_attr_set_monad.py @@ -168,6 +168,34 @@ def test27(self): def test28(self): groups = set(select(g for g in Group if not g.students.is_empty())) self.assertEqual(groups, {Group[41], Group[42], Group[44]}) + @raises_exception(NotImplementedError) + def test29(self): + students = select(g.students.select(lambda s: s.scholarship > 0) for g in Group if g.department == 101)[:] + def test30a(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.select(lambda s: s.scholarship > 0))[:] + self.assertEqual(set(groups), {Group[41]}) + def test30b(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.filter(lambda s: s.scholarship > 0))[:] + self.assertEqual(set(groups), {Group[41]}) + def test30c(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.select())[:] + self.assertEqual(set(groups), {Group[41]}) + def test30d(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 + and s in g.students.filter())[:] + self.assertEqual(set(groups), {Group[41]}) + def test31(self): + s = Student[2] + groups = select(g for g in Group if g.department == 101 and 
g.students.exists(lambda s: s.scholarship > 0))[:] + self.assertEqual(set(groups), {Group[41]}) + if __name__ == "__main__": unittest.main() From c507aef3615b240dbe4d3d504fa96db77ccd1006 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 2 Apr 2018 18:11:15 +0300 Subject: [PATCH 325/547] Translator refactoring --- pony/orm/core.py | 7 ++--- pony/orm/sqltranslation.py | 55 +++++++++++++++++++++----------------- 2 files changed, 34 insertions(+), 28 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 768587500..856678c10 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5308,11 +5308,11 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False pickled_tree = pickle_ast(tree) tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls - translator = translator_cls(tree_copy, filter_num, extractors, vars, vartypes, left_join=left_join) + translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) name_path = translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) - try: translator = translator_cls(tree_copy, filter_num, extractors, vars, vartypes, + try: translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), left_join=True, optimize=name_path) except OptimizationFailed: translator.optimization_failed = True translator.pickled_tree = pickled_tree @@ -5669,7 +5669,8 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) translator_cls = prev_translator.__class__ new_translator = translator_cls( - tree_copy, prev_translator.original_filter_num, prev_translator.extractors, None, prev_translator.vartypes, + tree_copy, None, prev_translator.original_filter_num, + prev_translator.extractors, None, 
prev_translator.vartypes.copy(), left_join=True, optimize=name_path) new_translator = query._reapply_filters(new_translator) new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 62192a23a..db144e520 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -86,7 +86,7 @@ def dispatch(translator, node): def dispatch_external(translator, node): varkey = translator.filter_num, node.src - t = translator.vartypes[varkey] + t = translator.root_translator.vartypes[varkey] tt = type(t) if t is NoneType: monad = ConstMonad.new(translator, None) @@ -165,28 +165,36 @@ def call(translator, method, node): else: throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: %s' % ast2src(node)) return monad - def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_translator=None, left_join=False, optimize=None): + def __init__(translator, tree, parent_translator, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) translator.can_be_cached = True - translator.database = None - translator.lambda_argnames = None - translator.filter_num = translator.original_filter_num = filter_num + translator.parent = parent_translator + if parent_translator is None: + translator.root_translator = translator + translator.database = None + translator.subquery = Subquery(left_join=left_join) + assert filter_num is not None + translator.filter_num = translator.original_filter_num = filter_num + else: + translator.root_translator = parent_translator.root_translator + translator.database = parent_translator.database + translator.subquery = Subquery(parent_translator.subquery, left_join=left_join) + translator.filter_num = parent_translator.filter_num + 
translator.original_filter_num = None translator.extractors = extractors + translator.vars = vars + translator.vartypes = vartypes + translator.lambda_argnames = None translator.method_argnames_mapping_stack = [] translator.func_extractors_map = {} - translator.vars = vars.copy() if vars is not None else None - translator.vartypes = vartypes.copy() translator.getattr_values = {} translator.func_vartypes = {} - translator.parent = parent_translator translator.left_join = left_join translator.optimize = optimize translator.from_optimized = False translator.optimization_failed = False - if not parent_translator: subquery = Subquery(left_join=left_join) - else: subquery = Subquery(parent_translator.subquery, left_join=left_join) - translator.subquery = subquery + subquery = translator.subquery tablerefs = subquery.tablerefs translator.distinct = False translator.conditions = subquery.conditions @@ -221,7 +229,7 @@ def __init__(translator, tree, filter_num, extractors, vars, vartypes, parent_tr name: ObjectIterMonad(translator, tableref, entity)}) else: assert False # pragma: no cover elif src: - iterable = translator.vartypes[translator.filter_num, src] + iterable = translator.root_translator.vartypes[translator.filter_num, src] if not isinstance(iterable, SetType): throw(TranslationError, 'Inside declarative query, iterator must be entity. 
' 'Got: for %s in %s' % (name, ast2src(qual.iter))) @@ -645,6 +653,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ translator.filter_num = filter_num translator.extractors.update(extractors) translator.vars = vars.copy() if vars is not None else None + translator.vartypes = translator.vartypes.copy() # make HashableDict mutable again translator.vartypes.update(vartypes) translator.lambda_argnames = list(argnames) translator.original_names = original_names @@ -676,7 +685,7 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ - subtranslator = translator_cls(inner_tree, translator.filter_num, translator.extractors, translator.vars, translator.vartypes, translator) + subtranslator = translator_cls(inner_tree, translator) return QuerySetMonad(translator, subtranslator) def postGenExprIf(translator, node): monad = node.test.monad @@ -807,7 +816,7 @@ def preCallFunc(translator, node): for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ - subtranslator = translator_cls(inner_expr, translator.filter_num, translator.extractors, translator.vars, translator.vartypes, translator) + subtranslator = translator_cls(inner_expr, translator) monad = QuerySetMonad(translator, subtranslator) if method_name == 'exists': monad = monad.nonzero() @@ -1288,19 +1297,17 @@ def __call__(monad, *args, **kwargs): func_ast, func_extractors = create_extractors( func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) - t = translator - while t.parent is not None: - t = t.parent - if func not in t.func_extractors_map: + root_translator = translator.root_translator + if func not in root_translator.func_extractors_map: func_vars, func_vartypes = 
extract_vars(func_filter_num, func_extractors, func.__globals__, {}, cells) translator.database.provider.normalize_vars(func_vars, func_vartypes) if func.__closure__: translator.can_be_cached = False if func_extractors: - t.func_extractors_map[func] = func_extractors - t.func_vartypes.update(func_vartypes) - t.vartypes.update(func_vartypes) - t.vars.update(func_vars) + root_translator.func_extractors_map[func] = func_extractors + root_translator.func_vartypes.update(func_vartypes) + root_translator.vartypes.update(func_vartypes) + root_translator.vars.update(func_vars) stack = translator.method_argnames_mapping_stack stack.append(name_mapping) @@ -2260,9 +2267,7 @@ def call(monad, obj_monad, name_monad): if isinstance(name_monad, ConstMonad): attrname = name_monad.value elif isinstance(name_monad, ParamMonad): - translator = monad.translator - while translator.parent: - translator = translator.parent + translator = monad.translator.root_translator key = name_monad.paramkey[0] if key in translator.getattr_values: attrname = translator.getattr_values[key] From 66835ef9c6a5485d714f55713caae68b59946dfc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 11 Apr 2018 12:59:11 +0300 Subject: [PATCH 326/547] Add extract_outer_conditions option to AttrSetMonad._subselect() --- pony/orm/sqltranslation.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index db144e520..82ca6743c 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -223,7 +223,7 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No if isinstance(monad, EntityMonad): tablerefs[name] = TableRef(subquery, name, entity) elif isinstance(monad, AttrSetMonad): - translator.subquery = monad._subselect(translator.subquery) + translator.subquery = monad._subselect(translator.subquery, extract_outer_conditions=False) tableref = monad.tableref 
translator.method_argnames_mapping_stack.append({ name: ObjectIterMonad(translator, tableref, entity)}) @@ -901,6 +901,7 @@ def __init__(subquery, parent_subquery=None, left_join=False): subquery.left_join = left_join subquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] subquery.conditions = [] + subquery.outer_conditions = [] subquery.tablerefs = {} if parent_subquery is None: subquery.alias_counters = {} @@ -2660,6 +2661,7 @@ def _aggregated_scalar_subselect(monad, make_aggr, extra_grouping=False): optimized = True if not translator.from_optimized: from_ast = monad.subquery.from_ast[1:] + assert subquery.outer_conditions from_ast[0] = from_ast[0] + [ sqland(subquery.outer_conditions) ] translator.subquery.from_ast.extend(from_ast) translator.from_optimized = True @@ -2727,7 +2729,7 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F expr_ast = [ 'COLUMN', alias, expr_name ] if coalesce_to_zero: expr_ast = [ 'COALESCE', expr_ast, [ 'VALUE', 0 ] ] return expr_ast, False - def _subselect(monad, subquery=None): + def _subselect(monad, subquery=None, extract_outer_conditions=True): if monad.subquery is not None: return monad.subquery attr = monad.attr translator = monad.translator @@ -2737,7 +2739,7 @@ def _subselect(monad, subquery=None): subquery.expr_list = monad.make_expr_list() if not attr.reverse and not attr.is_required: subquery.conditions.extend([ 'IS_NOT_NULL', expr ] for expr in subquery.expr_list) - if subquery is not translator.subquery: + if subquery is not translator.subquery and extract_outer_conditions: outer_cond = subquery.from_ast[1].pop() if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] else: subquery.outer_conditions = [ outer_cond ] @@ -2797,6 +2799,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): else: if not translator.from_optimized: from_ast = subquery.from_ast[1:] + assert subquery.outer_conditions from_ast[0] = from_ast[0] + [ sqland(subquery.outer_conditions) ] 
translator.subquery.from_ast.extend(from_ast) translator.from_optimized = True From 3599beb55c83dbe832b84614bcc5ba1efd660f92 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 14 Apr 2018 16:24:05 +0300 Subject: [PATCH 327/547] Improved handling of nullable expressions --- pony/orm/sqltranslation.py | 215 ++++++++++++++++++++----------------- 1 file changed, 118 insertions(+), 97 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 82ca6743c..b9446dcb6 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -504,7 +504,7 @@ def ast_transformer(ast): if is_not_null_checks: for monad in translator.expr_monads: if isinstance(monad, ObjectIterMonad): pass - elif isinstance(monad, AttrMonad) and not monad.attr.nullable: pass + elif not monad.nullable: pass else: notnull_conditions = [ [ 'IS_NOT_NULL', column_ast ] for column_ast in monad.getsql() ] if monad.aggregated: having_conditions.extend(notnull_conditions) @@ -874,21 +874,25 @@ def postIfExp(translator, node): elif not translator.row_value_syntax: throw(NotImplementedError) else: then_sql, else_sql = [ 'ROW' ] + then_sql, [ 'ROW' ] + else_sql expr = [ 'CASE', None, [ [ test_sql, then_sql ] ], else_sql ] - result = ExprMonad.new(translator, result_type, expr) + result = ExprMonad.new(translator, result_type, expr, + nullable=test_monad.nullable or then_monad.nullable or else_monad.nullable) result.aggregated = test_monad.aggregated or then_monad.aggregated or else_monad.aggregated return result -def coerce_monads(m1, m2): +def coerce_monads(m1, m2, for_comparison=False): result_type = coerce_types(m1.type, m2.type) - if result_type in numeric_types and bool in (m1.type, m2.type) and result_type is not bool: + if result_type in numeric_types and bool in (m1.type, m2.type) and ( + result_type is not bool or not for_comparison): translator = m1.translator if translator.dialect == 'PostgreSQL': + if result_type is bool: + result_type = int if m1.type is 
bool: - new_m1 = NumericExprMonad(translator, int, [ 'TO_INT', m1.getsql()[0] ]) + new_m1 = NumericExprMonad(translator, int, [ 'TO_INT', m1.getsql()[0] ], nullable=m1.nullable) new_m1.aggregated = m1.aggregated m1 = new_m1 if m2.type is bool: - new_m2 = NumericExprMonad(translator, int, [ 'TO_INT', m2.getsql()[0] ]) + new_m2 = NumericExprMonad(translator, int, [ 'TO_INT', m2.getsql()[0] ], nullable=m2.nullable) new_m2.aggregated = m2.aggregated m2 = new_m2 return result_type, m1, m2 @@ -1064,10 +1068,11 @@ class MonadMixin(with_metaclass(MonadMeta)): class Monad(with_metaclass(MonadMeta)): disable_distinct = False disable_ordering = False - def __init__(monad, translator, type): + def __init__(monad, translator, type, nullable=True): monad.node = None monad.translator = translator monad.type = type + monad.nullable = nullable monad.mixin_init() def mixin_init(monad): pass @@ -1110,7 +1115,7 @@ def count(monad, distinct=None): '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. 
Got: {EXPR}' % translator.dialect) - result = ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ]) + result = ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ], nullable=False) result.aggregated = True return result def aggregate(monad, func_name, distinct=None, sep=None): @@ -1152,7 +1157,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): if func_name == 'GROUP_CONCAT': if sep is not None: aggr_ast.append(['VALUE', sep]) - result = ExprMonad.new(translator, result_type, aggr_ast) + result = ExprMonad.new(translator, result_type, aggr_ast, nullable=True) result.aggregated = True return result def __call__(monad, *args, **kwargs): throw(TypeError) @@ -1170,9 +1175,9 @@ def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) def cast_from_json(monad, type): assert False, monad def to_int(monad): - return NumericExprMonad(monad.translator, int, [ 'TO_INT', monad.getsql()[0] ]) + return NumericExprMonad(monad.translator, int, [ 'TO_INT', monad.getsql()[0] ], nullable=monad.nullable) def to_real(monad): - return NumericExprMonad(monad.translator, float, [ 'TO_REAL', monad.getsql()[0] ]) + return NumericExprMonad(monad.translator, float, [ 'TO_REAL', monad.getsql()[0] ], nullable=monad.nullable) def distinct_from_monad(distinct, default=None): if distinct is None: @@ -1182,10 +1187,10 @@ def distinct_from_monad(distinct, default=None): throw(TypeError, '`distinct` value should be True or False. 
Got: %s' % ast2src(distinct.node)) class RawSQLMonad(Monad): - def __init__(monad, translator, rawtype, varkey): + def __init__(monad, translator, rawtype, varkey, nullable=True): if rawtype.result_type is None: type = rawtype else: type = normalize_type(rawtype.result_type) - Monad.__init__(monad, translator, type) + Monad.__init__(monad, translator, type, nullable=nullable) monad.rawtype = rawtype monad.varkey = varkey def contains(monad, item, not_in=False): @@ -1197,7 +1202,7 @@ def contains(monad, item, not_in=False): '%s database provider does not support tuples. Got: {EXPR} ' % translator.dialect) op = 'NOT_IN' if not_in else 'IN' sql = [ op, expr, monad.getsql() ] - return BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql, nullable=item.nullable) def nonzero(monad): return monad def getsql(monad, subquery=None): provider = monad.translator.database.provider @@ -1254,7 +1259,7 @@ def raise_forgot_parentheses(monad): class MethodMonad(Monad): def __init__(monad, parent, attrname): - Monad.__init__(monad, parent.translator, 'METHOD') + Monad.__init__(monad, parent.translator, 'METHOD', nullable=False) monad.parent = parent monad.attrname = attrname def getattr(monad, attrname): @@ -1353,7 +1358,7 @@ def contains(monad, x, not_in=False): sql = sqland([ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) else: sql = sqlor([ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) - return BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql, nullable=x.nullable or any(item.nullable for item in monad.items)) def getsql(monad, subquery=None): return [ [ 'ROW' ] + [ item.getsql()[0] for item in monad.items ] ] @@ -1396,31 +1401,41 @@ def __pow__(monad, monad2): left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return NumericExprMonad(translator, float, [ 'POW', left_sql[0], right_sql[0] ]) + return 
NumericExprMonad(translator, float, [ 'POW', left_sql[0], right_sql[0] ], + nullable=monad.nullable or monad2.nullable) def __neg__(monad): sql = monad.getsql()[0] translator = monad.translator - return NumericExprMonad(translator, monad.type, [ 'NEG', sql ]) + return NumericExprMonad(translator, monad.type, [ 'NEG', sql ], nullable=monad.nullable) def abs(monad): sql = monad.getsql()[0] translator = monad.translator - return NumericExprMonad(translator, monad.type, [ 'ABS', sql ]) + return NumericExprMonad(translator, monad.type, [ 'ABS', sql ], nullable=monad.nullable) def nonzero(monad): translator = monad.translator - return CmpMonad('!=', monad, ConstMonad.new(translator, 0)) + sql = monad.getsql()[0] + if not (translator.dialect == 'PostgreSQL' and monad.type is bool): + sql = [ 'NE', sql, [ 'VALUE', 0 ] ] + return BoolExprMonad(translator, sql, nullable=False) def negate(monad): + sql = monad.getsql()[0] translator = monad.translator - result = CmpMonad('==', monad, ConstMonad.new(translator, 0)) - if isinstance(monad, AttrMonad) and not monad.attr.nullable: - return result - sql = [ 'OR', result.getsql()[0], [ 'IS_NULL', monad.getsql()[0] ] ] - return BoolExprMonad(translator, sql) + pg_bool = translator.dialect == 'PostgreSQL' and monad.type is bool + result_sql = [ 'NOT', sql ] if pg_bool else [ 'EQ', sql, [ 'VALUE', 0 ] ] + if monad.nullable: + if isinstance(monad, AttrMonad): + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + elif pg_bool: + result_sql = [ 'NOT', [ 'COALESCE', sql, [ 'VALUE', True ] ] ] + else: + result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', 0 ] ], [ 'VALUE', 0 ] ] + return BoolExprMonad(translator, result_sql, nullable=False) def numeric_attr_factory(name): def attr_func(monad): sql = [ name, monad.getsql()[0] ] translator = monad.translator - return NumericExprMonad(translator, int, sql) + return NumericExprMonad(translator, int, sql, nullable=monad.nullable) attr_func.__name__ = name.lower() return attr_func @@ -1431,7 
+1446,8 @@ def datetime_binop(monad, monad2): _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) expr_monad_cls = DateExprMonad if monad.type is date else DatetimeExprMonad delta = monad2.value if isinstance(monad2, TimedeltaConstMonad) else monad2.getsql()[0] - return expr_monad_cls(translator, monad.type, [ sqlop, monad.getsql()[0], delta ]) + return expr_monad_cls(translator, monad.type, [ sqlop, monad.getsql()[0], delta ], + nullable=monad.nullable or monad2.nullable) datetime_binop.__name__ = sqlop return datetime_binop @@ -1461,7 +1477,7 @@ def mixin_init(monad): def call_date(monad): translator = monad.translator sql = [ 'DATE', monad.getsql()[0] ] - return ExprMonad.new(translator, date, sql) + return ExprMonad.new(translator, date, sql, nullable=monad.nullable) attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') @@ -1477,7 +1493,8 @@ def string_binop(monad, monad2): left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return StringExprMonad(translator, monad.type, [ sqlop, left_sql[0], right_sql[0] ]) + return StringExprMonad(translator, monad.type, [ sqlop, left_sql[0], right_sql[0] ], + nullable=monad.nullable or monad2.nullable) string_binop.__name__ = sqlop return string_binop @@ -1486,7 +1503,7 @@ def func(monad): sql = monad.getsql() assert len(sql) == 1 translator = monad.translator - return StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) + return StringExprMonad(translator, monad.type, [ sqlop, sql[0] ], nullable=monad.nullable) func.__name__ = sqlop return func @@ -1539,7 +1556,8 @@ def __getitem__(monad, index): len_sql = [ 'SUB', stop_sql, start.getsql()[0] ] sql = [ 'SUBSTR', expr_sql, start_sql, len_sql ] - return StringExprMonad(translator, monad.type, sql) + return StringExprMonad(translator, monad.type, sql, + nullable=monad.nullable or start.nullable or stop is not None and stop.nullable) 
if isinstance(monad, StringConstMonad) and isinstance(index, NumericConstMonad): return ConstMonad.new(translator, monad.value[index.value]) @@ -1554,17 +1572,30 @@ def __getitem__(monad, index): inner_sql = index.getsql()[0] index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] - return StringExprMonad(translator, monad.type, sql) + return StringExprMonad(translator, monad.type, sql, nullable=monad.nullable) def negate(monad): sql = monad.getsql()[0] translator = monad.translator - result = BoolExprMonad(translator, [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) + if translator.dialect == 'Oracle': + result_sql = [ 'IS_NULL', sql ] + else: + result_sql = [ 'EQ', sql, [ 'VALUE', '' ] ] + if monad.nullable: + if isinstance(monad, AttrMonad): + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + else: + result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', '' ] ], [ 'VALUE', '' ]] + result = BoolExprMonad(translator, result_sql, nullable=False) result.aggregated = monad.aggregated return result def nonzero(monad): sql = monad.getsql()[0] translator = monad.translator - result = BoolExprMonad(translator, [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ]]) + if translator.dialect == 'Oracle': + result_sql = [ 'IS_NOT_NULL', sql ] + else: + result_sql = [ 'NE', sql, [ 'VALUE', '' ] ] + result = BoolExprMonad(translator, result_sql, nullable=False) result.aggregated = monad.aggregated return result def len(monad): @@ -1608,9 +1639,14 @@ def _like(monad, item, before=None, after=None, not_like=False): if before and after: item_sql = [ 'CONCAT', [ 'VALUE', before ], item_sql, [ 'VALUE', after ] ] elif before: item_sql = [ 'CONCAT', [ 'VALUE', before ], item_sql ] elif after: item_sql = [ 'CONCAT', item_sql, [ 'VALUE', after ] ] - sql = [ 'NOT_LIKE' if not_like else 'LIKE', monad.getsql()[0], item_sql ] - if escape: sql.append([ 'VALUE', '!' 
]) - return BoolExprMonad(translator, sql) + sql = monad.getsql()[0] + if not_like and monad.nullable and not isinstance(monad, AttrMonad) and translator.dialect != 'Oracle': + sql = [ 'COALESCE', sql, [ 'VALUE', '' ] ] + result_sql = [ 'NOT_LIKE' if not_like else 'LIKE', sql, item_sql ] + if escape: result_sql.append([ 'VALUE', '!' ]) + if not_like and monad.nullable and (isinstance(monad, AttrMonad) or translator.dialect == 'Oracle'): + result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] + return BoolExprMonad(translator, result_sql, nullable=not_like) def strip(monad, chars, strip_type): translator = monad.translator if chars is not None and not are_comparable_types(monad.type, chars.type, None): @@ -1620,7 +1656,7 @@ def strip(monad, chars, strip_type): parent_sql = monad.getsql()[0] sql = [ strip_type, parent_sql ] if chars is not None: sql.append(chars.getsql()[0]) - return StringExprMonad(translator, monad.type, sql) + return StringExprMonad(translator, monad.type, sql, nullable=monad.nullable) def call_strip(monad, chars=None): return monad.strip(chars, 'TRIM') def call_lstrip(monad, chars=None): @@ -1740,6 +1776,7 @@ def __init__(monad, parent, attr): Monad.__init__(monad, parent.translator, attr_type) monad.parent = parent monad.attr = attr + monad.nullable = attr.nullable def getsql(monad, subquery=None): parent = monad.parent attr = monad.attr @@ -1772,24 +1809,7 @@ def __init__(monad, parent, attr): parent_subquery = parent_monad.tableref.subquery monad.tableref = parent_subquery.add_tableref(name_path, parent_monad.tableref, attr) -class StringAttrMonad(StringMixin, AttrMonad): - def negate(monad): - sql = monad.getsql()[0] - translator = monad.translator - result_sql = [ 'EQ', [ 'LENGTH', sql ], [ 'VALUE', 0 ] ] - if monad.attr.nullable: - result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] - result = BoolExprMonad(translator, result_sql) - result.aggregated = monad.aggregated - return result - def nonzero(monad): - sql = monad.getsql()[0] - 
translator = monad.translator - result_sql = [ 'GT', [ 'LENGTH', sql ], [ 'VALUE', 0 ] ] - result = BoolExprMonad(translator, result_sql) - result.aggregated = monad.aggregated - return result - +class StringAttrMonad(StringMixin, AttrMonad): pass class NumericAttrMonad(NumericMixin, AttrMonad): pass class DateAttrMonad(DateMixin, AttrMonad): pass class TimeAttrMonad(TimeMixin, AttrMonad): pass @@ -1822,7 +1842,7 @@ def __new__(cls, *args): return Monad.__new__(cls) def __init__(monad, translator, type, paramkey): type = normalize_type(type) - Monad.__init__(monad, translator, type) + Monad.__init__(monad, translator, type, nullable=False) monad.paramkey = paramkey if not isinstance(type, EntityMeta): provider = translator.database.provider @@ -1860,7 +1880,7 @@ def getsql(monad, subquery=None): class ExprMonad(Monad): @staticmethod - def new(translator, type, sql): + def new(translator, type, sql, nullable=True): if type in numeric_types: cls = NumericExprMonad elif type is unicode: cls = StringExprMonad elif type is date: cls = DateExprMonad @@ -1870,12 +1890,12 @@ def new(translator, type, sql): elif type is Json: cls = JsonExprMonad elif isinstance(type, EntityMeta): cls = ObjectExprMonad else: throw(NotImplementedError, type) # pragma: no cover - return cls(translator, type, sql) - def __new__(cls, *args): + return cls(translator, type, sql, nullable=nullable) + def __new__(cls, *args, **kwargs): if cls is ExprMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, type, sql): - Monad.__init__(monad, translator, type) + def __init__(monad, translator, type, sql, nullable=True): + Monad.__init__(monad, translator, type, nullable=nullable) monad.sql = sql def getsql(monad, subquery=None): return [ monad.sql ] @@ -1955,7 +1975,7 @@ def __new__(cls, *args): return Monad.__new__(cls) def __init__(monad, translator, value): value_type, value = normalize(value) - Monad.__init__(monad, translator, 
value_type) + Monad.__init__(monad, translator, value_type, nullable=value_type is NoneType) monad.value = value def getsql(monad, subquery=None): return [ [ 'VALUE', monad.value ] ] @@ -1982,15 +2002,15 @@ class TimedeltaConstMonad(TimedeltaMixin, ConstMonad): pass class DatetimeConstMonad(DatetimeMixin, ConstMonad): pass class BoolMonad(Monad): - def __init__(monad, translator): - Monad.__init__(monad, translator, bool) + def __init__(monad, translator, nullable=True): + Monad.__init__(monad, translator, bool, nullable=nullable) sql_negation = { 'IN' : 'NOT_IN', 'EXISTS' : 'NOT_EXISTS', 'LIKE' : 'NOT_LIKE', 'BETWEEN' : 'NOT_BETWEEN', 'IS_NULL' : 'IS_NOT_NULL' } sql_negation.update((value, key) for key, value in items_list(sql_negation)) class BoolExprMonad(BoolMonad): - def __init__(monad, translator, sql): - BoolMonad.__init__(monad, translator) + def __init__(monad, translator, sql, nullable=True): + BoolMonad.__init__(monad, translator, nullable=nullable) monad.sql = sql def getsql(monad, subquery=None): return [ monad.sql ] @@ -2004,8 +2024,8 @@ def negate(monad): elif negated_op == 'NOT': assert len(sql) == 2 negated_sql = sql[1] - else: return NotMonad(translator, sql) - return BoolExprMonad(translator, negated_sql) + else: return NotMonad(monad) + return BoolExprMonad(translator, negated_sql, nullable=monad.nullable) cmp_ops = { '>=' : 'GE', '>' : 'GT', '<=' : 'LE', '<' : 'LT' } @@ -2027,8 +2047,8 @@ def __init__(monad, op, left, right): elif op == 'is': op = '==' elif op == 'is not': op = '!=' check_comparable(left, right, op) - result_type, left, right = coerce_monads(left, right) - BoolMonad.__init__(monad, translator) + result_type, left, right = coerce_monads(left, right, for_comparison=True) + BoolMonad.__init__(monad, translator, nullable=left.nullable or right.nullable) monad.op = op monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False) @@ -2082,7 +2102,8 @@ def __init__(monad, operands): elif 
isinstance(operand, LogicalBinOpMonad) and monad.binop == operand.binop: items.extend(operand.operands) else: items.append(operand) - BoolMonad.__init__(monad, items[0].translator) + nullable = any(item.nullable for item in items) + BoolMonad.__init__(monad, items[0].translator, nullable=nullable) monad.operands = items def getsql(monad, subquery=None): result = [ monad.binop ] @@ -2101,7 +2122,7 @@ class OrMonad(LogicalBinOpMonad): class NotMonad(BoolMonad): def __init__(monad, operand): if operand.type is not bool: operand = operand.nonzero() - BoolMonad.__init__(monad, operand.translator) + BoolMonad.__init__(monad, operand.translator, nullable=operand.nullable) monad.operand = operand def negate(monad): return monad.operand @@ -2192,7 +2213,7 @@ def call(monad, year, month, day): return ConstMonad.new(translator, date(year.value, month.value, day.value)) def call_today(monad): translator = monad.translator - return DateExprMonad(translator, date, [ 'TODAY' ]) + return DateExprMonad(translator, date, [ 'TODAY' ], nullable=monad.nullable) class FuncTimeMonad(FuncMonad): func = time @@ -2231,7 +2252,7 @@ def call(monad, year, month, day, hour=None, minute=None, second=None, microseco return ConstMonad.new(translator, value) def call_now(monad): translator = monad.translator - return DatetimeExprMonad(translator, datetime, [ 'NOW' ]) + return DatetimeExprMonad(translator, datetime, [ 'NOW' ], nullable=monad.nullable) class FuncBetweenMonad(FuncMonad): func = between @@ -2242,7 +2263,7 @@ def call(monad, x, a, b): '%s instance cannot be argument of between() function: {EXPR}' % x.type.__name__) translator = x.translator sql = [ 'BETWEEN', x.getsql()[0], a.getsql()[0], b.getsql()[0] ] - return BoolExprMonad(translator, sql) + return BoolExprMonad(translator, sql, nullable=x.nullable or a.nullable or b.nullable) class FuncConcatMonad(FuncMonad): func = concat @@ -2255,7 +2276,7 @@ def call(monad, *args): if isinstance(t, EntityMeta) or type(t) in (tuple, SetType): 
throw(TranslationError, 'Invalid argument of concat() function: %s' % ast2src(arg.node)) result_ast.extend(arg.getsql()) - return ExprMonad.new(translator, unicode, result_ast) + return ExprMonad.new(translator, unicode, result_ast, nullable=any(arg.nullable for arg in args)) class FuncLenMonad(FuncMonad): func = len @@ -2287,7 +2308,7 @@ def call(monad, x=None, distinct=None): translator = monad.translator if isinstance(x, StringConstMonad) and x.value == '*': x = None if x is not None: return x.count(distinct) - result = ExprMonad.new(translator, int, [ 'COUNT', None ]) + result = ExprMonad.new(translator, int, [ 'COUNT', None ], nullable=False) result.aggregated = True return result @@ -2329,7 +2350,7 @@ def call(monad, *args): result[i].append(sql) sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] if not isinstance(t, EntityMeta): sql = sql[0] - return ExprMonad.new(translator, t, sql) + return ExprMonad.new(translator, t, sql, nullable=all(arg.nullable for arg in args)) class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct @@ -2373,9 +2394,9 @@ def minmax(monad, sqlop, *args): args = list(args) for i, arg in enumerate(args): if arg.type is bool: - args[i] = NumericExprMonad(translator, int, [ 'TO_INT', arg.getsql() ]) + args[i] = NumericExprMonad(translator, int, [ 'TO_INT', arg.getsql() ], nullable=arg.nullable) sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ] - return ExprMonad.new(translator, t, sql) + return ExprMonad.new(translator, t, sql, nullable=any(arg.nullable for arg in args)) class FuncSelectMonad(FuncMonad): func = core.select @@ -2399,7 +2420,7 @@ def call(monad, expr): class DescMonad(Monad): def __init__(monad, expr): - Monad.__init__(monad, expr.translator, expr.type) + Monad.__init__(monad, expr.translator, expr.type, nullable=expr.nullable) monad.expr = expr def getsql(monad): return [ [ 'DESC', item ] for item in monad.expr.getsql() ] @@ -2420,7 +2441,7 @@ def __init__(monad, translator, 
type): FuncMonad.__init__(monad, translator, type) translator.query_result_is_cacheable = False def __call__(monad): - return NumericExprMonad(monad.translator, float, [ 'RANDOM' ]) + return NumericExprMonad(monad.translator, float, [ 'RANDOM' ], nullable=False) class SetMixin(MonadMixin): forced_distinct = False @@ -2468,7 +2489,7 @@ def contains(monad, item, not_in=False): else: conditions += [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item.getsql(), expr_list) ] sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS', from_ast, [ 'WHERE' ] + conditions ] - result = BoolExprMonad(translator, sql_ast) + result = BoolExprMonad(translator, sql_ast, nullable=False) result.nogroup = True return result elif not not_in: @@ -2476,7 +2497,7 @@ def contains(monad, item, not_in=False): tableref = monad.make_tableref(translator.subquery) expr_list = monad.make_expr_list() expr_ast = sqland([ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(expr_list, item.getsql()) ]) - return BoolExprMonad(translator, expr_ast) + return BoolExprMonad(translator, expr_ast, nullable=False) else: subquery = Subquery(translator.subquery) tableref = monad.make_tableref(subquery) @@ -2491,7 +2512,7 @@ def contains(monad, item, not_in=False): conditions.extend(subquery.conditions) from_ast[-1][-1] = sqland([ from_ast[-1][-1] ] + conditions) expr_ast = sqland([ [ 'IS_NULL', expr ] for expr in expr_list ]) - return BoolExprMonad(translator, expr_ast) + return BoolExprMonad(translator, expr_ast, nullable=False) def getattr(monad, name): try: return Monad.getattr(monad, name) except AttributeError: pass @@ -2567,7 +2588,7 @@ def count(monad, distinct=None): else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr, extra_grouping) translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = ExprMonad.new(translator, int, sql_ast) + result = ExprMonad.new(translator, int, sql_ast, nullable=False) if optimized: result.aggregated = True else: result.nogroup = True return 
result @@ -2611,7 +2632,7 @@ def make_aggr(expr_list): else: result_type = item_type translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = ExprMonad.new(monad.translator, result_type, sql_ast) + result = ExprMonad.new(monad.translator, result_type, sql_ast, nullable=func_name != 'SUM') if optimized: result.aggregated = True else: result.nogroup = True return result @@ -2620,13 +2641,13 @@ def nonzero(monad): sql_ast = [ 'EXISTS', subquery.from_ast, [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] translator = monad.translator - return BoolExprMonad(translator, sql_ast) + return BoolExprMonad(translator, sql_ast, nullable=False) def negate(monad): subquery = monad._subselect() sql_ast = [ 'NOT_EXISTS', subquery.from_ast, [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] translator = monad.translator - return BoolExprMonad(translator, sql_ast) + return BoolExprMonad(translator, sql_ast, nullable=False) call_is_empty = negate def make_tableref(monad, subquery): parent = monad.parent @@ -2794,7 +2815,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], subquery.from_ast, [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] - result = ExprMonad.new(translator, result_type, sql_ast) + result = ExprMonad.new(translator, result_type, sql_ast, nullable=func_name != 'SUM') result.nogroup = True else: if not translator.from_optimized: @@ -2804,7 +2825,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): translator.subquery.from_ast.extend(from_ast) translator.from_optimized = True sql_ast = aggr_ast - result = ExprMonad.new(translator, result_type, sql_ast) + result = ExprMonad.new(translator, result_type, sql_ast, nullable=func_name != 'SUM') result.aggregated = True return result def getsql(monad, subquery=None): @@ -2898,17 +2919,17 @@ def contains(monad, item, not_in=False): having_ast = find_or_create_having_ast(subquery_ast) having_ast += 
in_conditions sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] - return BoolExprMonad(translator, sql_ast) + return BoolExprMonad(translator, sql_ast, nullable=False) def nonzero(monad): subquery_ast = monad.subtranslator.shallow_copy_of_subquery_ast() subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] translator = monad.translator - return BoolExprMonad(translator, subquery_ast) + return BoolExprMonad(translator, subquery_ast, nullable=False) def negate(monad): sql = monad.nonzero().sql assert sql[0] == 'EXISTS' translator = monad.translator - return BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:]) + return BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:], nullable=False) def count(monad, distinct=None): distinct = distinct_from_monad(distinct) translator = monad.translator @@ -2950,7 +2971,7 @@ def count(monad, distinct=None): else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] - return ExprMonad.new(translator, int, sql_ast) + return ExprMonad.new(translator, int, sql_ast, nullable=False) len = count def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) @@ -2985,7 +3006,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): result_type = unicode else: result_type = expr_type - return ExprMonad.new(translator, result_type, sql_ast) + return ExprMonad.new(translator, result_type, sql_ast, func_name != 'SUM') def call_count(monad, distinct=None): return monad.count(distinct=distinct) def call_sum(monad, distinct=None): From ec3ad652cd49c61159a7e62bc65d84d5c11a2ab1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 17 Jul 2018 15:37:29 +0300 Subject: [PATCH 328/547] Store unbound methods in translator method caches --- pony/orm/asttranslation.py | 29 ++++++++++++++++++----------- pony/orm/sqltranslation.py | 4 ++-- 2 files changed, 20 
insertions(+), 13 deletions(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 1e0d4a289..704985076 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -9,18 +9,25 @@ class TranslationError(Exception): pass +pre_method_caches = {} +post_method_caches = {} + class ASTTranslator(object): def __init__(translator, tree): translator.tree = tree - translator.pre_methods = {} - translator.post_methods = {} + translator_cls = translator.__class__ + pre_method_caches.setdefault(translator_cls, {}) + post_method_caches.setdefault(translator_cls, {}) def dispatch(translator, node): - cls = node.__class__ + translator_cls = translator.__class__ + pre_methods = pre_method_caches[translator_cls] + post_methods = post_method_caches[translator_cls] + node_cls = node.__class__ - try: pre_method = translator.pre_methods[cls] + try: pre_method = pre_methods[node_cls] except KeyError: - pre_method = getattr(translator, 'pre' + cls.__name__, translator.default_pre) - translator.pre_methods[cls] = pre_method + pre_method = getattr(translator_cls, 'pre' + node_cls.__name__, translator_cls.default_pre) + pre_methods[node_cls] = pre_method stop = translator.call(pre_method, node) if stop: return @@ -28,13 +35,13 @@ def dispatch(translator, node): for child in node.getChildNodes(): translator.dispatch(child) - try: post_method = translator.post_methods[cls] + try: post_method = post_methods[node_cls] except KeyError: - post_method = getattr(translator, 'post' + cls.__name__, translator.default_post) - translator.post_methods[cls] = post_method + post_method = getattr(translator_cls, 'post' + node_cls.__name__, translator_cls.default_post) + post_methods[node_cls] = post_method translator.call(post_method, node) def call(translator, method, node): - return method(node) + return method(translator, node) def default_pre(translator, node): pass def default_post(translator, node): @@ -62,7 +69,7 @@ def __init__(translator, tree): 
ASTTranslator.__init__(translator, tree) translator.dispatch(tree) def call(translator, method, node): - node.src = method(node) + node.src = method(translator, node) def default_post(translator, node): throw(NotImplementedError, node) def postGenExpr(translator, node): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index b9446dcb6..bbd416c72 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -82,7 +82,7 @@ def dispatch(translator, node): if hasattr(node, 'monad'): return # monad already assigned somehow if not getattr(node, 'external', False) or getattr(node, 'constant', False): return ASTTranslator.dispatch(translator, node) # default route - translator.call(translator.dispatch_external, node) + translator.call(translator.__class__.dispatch_external, node) def dispatch_external(translator, node): varkey = translator.filter_num, node.src @@ -127,7 +127,7 @@ def dispatch_external(translator, node): monad.aggregated = monad.nogroup = False def call(translator, method, node): - try: monad = method(node) + try: monad = method(translator, node) except Exception: exc_class, exc, tb = sys.exc_info() try: From 4139ecc8b674672d33897745ba763e5fe577e2e7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 16 Apr 2018 18:45:07 +0300 Subject: [PATCH 329/547] Refactoring: pass limit and offset instead of range to query._fetch() --- pony/orm/core.py | 15 +++++++++------ pony/orm/sqltranslation.py | 9 +++------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 856678c10..64fa2797f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5356,7 +5356,7 @@ def _get_translator(query, query_key, vars): del database._translator_cache[query_key] return None, vars.copy() return translator, new_vars - def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): + def _construct_sql_and_arguments(query, limit=None, offset=None, 
range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type if isinstance(expr_type, EntityMeta) and query._attrs_to_prefetch_dict: @@ -5367,7 +5367,8 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_fu query._key, vartypes=HashableDict(query._translator.vartypes), getattr_values=HashableDict(translator.getattr_values), - range=range, + limit=limit, + offset=offset, distinct=query._distinct, aggr_func=(aggr_func_name, aggr_func_distinct, sep), for_update=query._for_update, @@ -5379,7 +5380,7 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_fu cache_entry = database._constructed_sql_cache.get(sql_key) if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( - range, query._distinct, aggr_func_name, aggr_func_distinct, sep, + limit, offset, query._distinct, aggr_func_name, aggr_func_distinct, sep, query._for_update, query._nowait, attrs_to_prefetch) cache = database._get_cache() sql, adapter = database.provider.ast2sql(sql_ast) @@ -5397,9 +5398,9 @@ def _construct_sql_and_arguments(query, range=None, aggr_func_name=None, aggr_fu def get_sql(query): sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments() return sql - def _fetch(query, range=None): + def _fetch(query, limit=None, offset=None): translator = query._translator - sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(range) + sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(limit, offset) database = query._database cache = database._get_cache() if query._for_update: cache.immediate = True @@ -5761,7 +5762,9 @@ def __getitem__(query, key): if not start: return query._fetch() else: throw(TypeError, "Parameter 'stop' of slice object should be specified") if start >= stop: return [] - return query._fetch(range=(start, stop)) + limit = stop - start + offset = start + return 
query._fetch(limit, offset) @cut_traceback def limit(query, limit, offset=None): start = offset or 0 diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index bbd416c72..e62fc8369 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -425,7 +425,8 @@ def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False return next(iter(translator.aggregated_subquery_paths)) - def construct_sql_ast(translator, range=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, + def construct_sql_ast(translator, limit=None, offset=None, distinct=None, + aggr_func_name=None, aggr_func_distinct=None, sep=None, for_update=False, nowait=False, attrs_to_prefetch=(), is_not_null_checks=False): attr_offsets = None if distinct is None: distinct = translator.distinct @@ -526,12 +527,8 @@ def ast_transformer(ast): if translator.order and not aggr_func_name: sql_ast.append([ 'ORDER_BY' ] + translator.order) - if range: + if limit is not None: assert not aggr_func_name - start, stop = range - limit = stop - start - offset = start - assert limit is not None limit_section = [ 'LIMIT', [ 'VALUE', limit ]] if offset: limit_section.append([ 'VALUE', offset ]) sql_ast = sql_ast + [ limit_section ] From c4881c147476e3e480ab40ef3cb0bff423b58f61 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 1 Jul 2018 13:12:32 +0300 Subject: [PATCH 330/547] Move code around --- pony/orm/sqltranslation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e62fc8369..ab2399679 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -393,6 +393,10 @@ def func(value, converter=converter): translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in translator.row_layout ] translator.vars = None + def can_be_optimized(translator): 
+ if translator.groupby_monads: return False + if len(translator.aggregated_subquery_paths) != 1: return False + return next(iter(translator.aggregated_subquery_paths)) def shallow_copy_of_subquery_ast(translator, move_outer_conditions=True, is_not_null_checks=False): subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=False, is_not_null_checks=is_not_null_checks) assert attr_offsets is None @@ -421,10 +425,6 @@ def shallow_copy_of_subquery_ast(translator, move_outer_conditions=True, is_not_ else: where_ast.insert(1, outer_conditions) return [ 'SELECT', select_ast, from_ast, where_ast ] + other_ast - def can_be_optimized(translator): - if translator.groupby_monads: return False - if len(translator.aggregated_subquery_paths) != 1: return False - return next(iter(translator.aggregated_subquery_paths)) def construct_sql_ast(translator, limit=None, offset=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, for_update=False, nowait=False, attrs_to_prefetch=(), is_not_null_checks=False): From ec8ffa21863270809f1acf2ceda503d1e68e3ac5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 10 Jun 2018 21:20:47 +0300 Subject: [PATCH 331/547] Renaming: Subquery -> SqlQuery, shallow_copy_of_sql_ast() -> construct_subquery_ast() --- pony/orm/core.py | 6 +- pony/orm/dbproviders/oracle.py | 6 +- pony/orm/sqlbuilding.py | 6 +- pony/orm/sqltranslation.py | 362 +++++++++--------- .../test_declarative_join_optimization.py | 6 +- 5 files changed, 193 insertions(+), 193 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 64fa2797f..e70971f56 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5638,7 +5638,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names if argnames: if original_names: for name in argnames: - if name not in prev_translator.subquery.tablerefs: throw(TypeError, + if name not in prev_translator.sqlquery.tablerefs: throw(TypeError, 'Lambda argument %s does not correspond 
to any loop variable in original query' % name) else: expr_type = prev_translator.expr_type @@ -5649,7 +5649,7 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names filter_num = next(filter_num_counter) func_ast, extractors = create_extractors( - func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.subquery) + func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.sqlquery) if extractors: vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) @@ -5713,7 +5713,7 @@ def where(query, *args, **kwargs): def _apply_kwargs(query, kwargs, original_names=False): translator = query._translator if original_names: - tablerefs = translator.subquery.tablerefs + tablerefs = translator.sqlquery.tablerefs alias = translator.tree.quals[0].assign.name tableref = tablerefs[alias] entity = tableref.entity diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index b35d28c12..d8b97d873 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -165,7 +165,7 @@ def SELECT(builder, *sections): limit = last_section[1] if len(last_section) > 2: offset = last_section[2] sections = sections[:-1] - result = builder.subquery(*sections) + result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent if sections[0][0] == 'ROWID': @@ -179,14 +179,14 @@ def SELECT(builder, *sections): elif not offset: result = [ indent0, 'SELECT * FROM (\n' ] builder.indent += 1 - result.extend(builder.subquery(*sections)) + result.extend(builder._subquery(*sections)) builder.indent -= 1 result.extend((indent, ') WHERE ROWNUM <= ', builder(limit), '\n')) else: indent2 = indent + builder.indent_spaces result = [ indent0, 'SELECT %s FROM (\n' % x, indent2, 'SELECT t.*, ROWNUM "row-num" FROM (\n' ] builder.indent += 2 - 
result.extend(builder.subquery(*sections)) + result.extend(builder._subquery(*sections)) builder.indent -= 2 result.extend((indent2, ') t ')) if limit[0] == 'VALUE' and offset[0] == 'VALUE' \ diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index bde07d2c5..7e9a44089 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -235,7 +235,7 @@ def DELETE(builder, alias, from_ast, where=None): if alias is not None: builder.suppress_aliases = True if not where: return 'DELETE ', builder(from_ast) return 'DELETE ', builder(from_ast), builder(where) - def subquery(builder, *sections): + def _subquery(builder, *sections): builder.indent += 1 if not builder.inner_join_syntax: sections = move_conditions_from_inner_join_to_where(sections) @@ -246,7 +246,7 @@ def SELECT(builder, *sections): prev_suppress_aliases = builder.suppress_aliases builder.suppress_aliases = False try: - result = builder.subquery(*sections) + result = builder._subquery(*sections) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' @@ -258,7 +258,7 @@ def SELECT_FOR_UPDATE(builder, nowait, *sections): result = builder.SELECT(*sections) return result, 'FOR UPDATE NOWAIT\n' if nowait else 'FOR UPDATE\n' def EXISTS(builder, *sections): - result = builder.subquery(*sections) + result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent return 'EXISTS (\n', indent, 'SELECT 1\n', result, indent, ')' def NOT_EXISTS(builder, *sections): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index ab2399679..c1a2c2071 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -173,13 +173,13 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No if parent_translator is None: translator.root_translator = translator translator.database = None - translator.subquery = Subquery(left_join=left_join) + translator.sqlquery = SqlQuery(left_join=left_join) assert 
filter_num is not None translator.filter_num = translator.original_filter_num = filter_num else: translator.root_translator = parent_translator.root_translator translator.database = parent_translator.database - translator.subquery = Subquery(parent_translator.subquery, left_join=left_join) + translator.sqlquery = SqlQuery(parent_translator.sqlquery, left_join=left_join) translator.filter_num = parent_translator.filter_num translator.original_filter_num = None translator.extractors = extractors @@ -194,10 +194,10 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No translator.optimize = optimize translator.from_optimized = False translator.optimization_failed = False - subquery = translator.subquery - tablerefs = subquery.tablerefs + sqlquery = translator.sqlquery + tablerefs = sqlquery.tablerefs translator.distinct = False - translator.conditions = subquery.conditions + translator.conditions = sqlquery.conditions translator.having_conditions = [] translator.order = [] translator.inside_order_by = False @@ -221,9 +221,9 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No assert parent_translator and i == 0 entity = monad.type.item_type if isinstance(monad, EntityMonad): - tablerefs[name] = TableRef(subquery, name, entity) + tablerefs[name] = TableRef(sqlquery, name, entity) elif isinstance(monad, AttrSetMonad): - translator.subquery = monad._subselect(translator.subquery, extract_outer_conditions=False) + translator.sqlquery = monad._subselect(translator.sqlquery, extract_outer_conditions=False) tableref = monad.tableref translator.method_argnames_mapping_stack.append({ name: ObjectIterMonad(translator, tableref, entity)}) @@ -241,7 +241,7 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No 'Collection expected inside left join query. 
' 'Got: for %s in %s' % (name, ast2src(qual.iter))) translator.distinct = True - tableref = TableRef(subquery, name, entity) + tableref = TableRef(sqlquery, name, entity) tablerefs[name] = tableref tableref.make_join() node.monad = ObjectIterMonad(translator, tableref, entity) @@ -283,7 +283,7 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No else: can_affect_distinct = True if j == last_index: name_path = name else: name_path += '-' + attr.name - tableref = JoinedTableRef(subquery, name_path, parent_tableref, attr) + tableref = JoinedTableRef(sqlquery, name_path, parent_tableref, attr) if can_affect_distinct is not None: tableref.can_affect_distinct = can_affect_distinct tablerefs[name_path] = tableref @@ -324,7 +324,7 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No elif isinstance(monad, ObjectMixin): tableref = monad.tableref elif isinstance(monad, AttrSetMonad): - tableref = monad.make_tableref(translator.subquery) + tableref = monad.make_tableref(translator.sqlquery) else: assert False # pragma: no cover if translator.aggregated: translator.groupby_monads = [ monad ] @@ -397,7 +397,7 @@ def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False return next(iter(translator.aggregated_subquery_paths)) - def shallow_copy_of_subquery_ast(translator, move_outer_conditions=True, is_not_null_checks=False): + def construct_subquery_ast(translator, move_outer_conditions=True, is_not_null_checks=False): subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=False, is_not_null_checks=is_not_null_checks) assert attr_offsets is None assert len(subquery_ast) >= 3 and subquery_ast[0] == 'SELECT' @@ -498,7 +498,7 @@ def ast_transformer(ast): select_ast, attr_offsets = translator.expr_type._construct_select_clause_( translator.alias, distinct, translator.tableref.used_attrs, attrs_to_prefetch) 
sql_ast.append(select_ast) - sql_ast.append(translator.subquery.from_ast) + sql_ast.append(translator.sqlquery.from_ast) conditions = translator.conditions[:] having_conditions = translator.having_conditions[:] @@ -544,9 +544,9 @@ def construct_delete_sql_ast(translator): 'Delete query cannot contains GROUP BY section or aggregate functions') assert not translator.having_conditions tableref = expr_monad.tableref - from_ast = translator.subquery.from_ast + from_ast = translator.sqlquery.from_ast assert from_ast[0] == 'FROM' - if len(from_ast) == 2 and not translator.subquery.used_from_subquery: + if len(from_ast) == 2 and not translator.sqlquery.used_from_subquery: sql_ast = [ 'DELETE', None, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) @@ -738,7 +738,7 @@ def resolve_name(translator, name): i = argnames.index(name) return t.expr_monads[i] t = t.parent - tableref = translator.subquery.get_tableref(name) + tableref = translator.sqlquery.get_tableref(name) if tableref is not None: return ObjectIterMonad(translator, tableref, tableref.entity) return None @@ -896,46 +896,46 @@ def coerce_monads(m1, m2, for_comparison=False): max_alias_length = 30 -class Subquery(object): - def __init__(subquery, parent_subquery=None, left_join=False): - subquery.parent_subquery = parent_subquery - subquery.left_join = left_join - subquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] - subquery.conditions = [] - subquery.outer_conditions = [] - subquery.tablerefs = {} - if parent_subquery is None: - subquery.alias_counters = {} - subquery.expr_counter = itertools.count(1) +class SqlQuery(object): + def __init__(sqlquery, parent_sqlquery=None, left_join=False): + sqlquery.parent_sqlquery = parent_sqlquery + sqlquery.left_join = left_join + sqlquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] + sqlquery.conditions = [] + sqlquery.outer_conditions = [] + sqlquery.tablerefs = {} + if parent_sqlquery is None: + sqlquery.alias_counters 
= {} + sqlquery.expr_counter = itertools.count(1) else: - subquery.alias_counters = parent_subquery.alias_counters.copy() - subquery.expr_counter = parent_subquery.expr_counter - subquery.used_from_subquery = False - def get_tableref(subquery, name_path, from_subquery=False): - tableref = subquery.tablerefs.get(name_path) + sqlquery.alias_counters = parent_sqlquery.alias_counters.copy() + sqlquery.expr_counter = parent_sqlquery.expr_counter + sqlquery.used_from_subquery = False + def get_tableref(sqlquery, name_path, from_subquery=False): + tableref = sqlquery.tablerefs.get(name_path) if tableref is not None: - if from_subquery and subquery.parent_subquery is None: - subquery.used_from_subquery = True + if from_subquery and sqlquery.parent_sqlquery is None: + sqlquery.used_from_subquery = True return tableref - if subquery.parent_subquery: - return subquery.parent_subquery.get_tableref(name_path, from_subquery=True) + if sqlquery.parent_sqlquery: + return sqlquery.parent_sqlquery.get_tableref(name_path, from_subquery=True) return None __contains__ = get_tableref - def add_tableref(subquery, name_path, parent_tableref, attr): - tablerefs = subquery.tablerefs + def add_tableref(sqlquery, name_path, parent_tableref, attr): + tablerefs = sqlquery.tablerefs assert name_path not in tablerefs - tableref = JoinedTableRef(subquery, name_path, parent_tableref, attr) + tableref = JoinedTableRef(sqlquery, name_path, parent_tableref, attr) tablerefs[name_path] = tableref return tableref - def make_alias(subquery, name): + def make_alias(sqlquery, name): name = name[:max_alias_length-3].lower() - i = subquery.alias_counters.setdefault(name, 0) + 1 + i = sqlquery.alias_counters.setdefault(name, 0) + 1 alias = name if i == 1 and name != 't' else '%s-%d' % (name, i) - subquery.alias_counters[name] = i + sqlquery.alias_counters[name] = i return alias - def join_table(subquery, parent_alias, alias, table_name, join_cond): + def join_table(sqlquery, parent_alias, alias, table_name, 
join_cond): new_item = [alias, 'TABLE', table_name, join_cond] - from_ast = subquery.from_ast + from_ast = sqlquery.from_ast for i in xrange(1, len(from_ast)): if from_ast[i][0] == parent_alias: for j in xrange(i+1, len(from_ast)): @@ -945,9 +945,9 @@ def join_table(subquery, parent_alias, alias, table_name, join_cond): from_ast.append(new_item) class TableRef(object): - def __init__(tableref, subquery, name, entity): - tableref.subquery = subquery - tableref.alias = subquery.make_alias(name) + def __init__(tableref, sqlquery, name, entity): + tableref.sqlquery = sqlquery + tableref.alias = sqlquery.make_alias(name) tableref.name_path = tableref.alias tableref.entity = entity tableref.joined = False @@ -956,18 +956,18 @@ def __init__(tableref, subquery, name, entity): def make_join(tableref, pk_only=False): entity = tableref.entity if not tableref.joined: - subquery = tableref.subquery - subquery.from_ast.append([ tableref.alias, 'TABLE', entity._table_ ]) + sqlquery = tableref.sqlquery + sqlquery.from_ast.append([ tableref.alias, 'TABLE', entity._table_ ]) if entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(tableref.alias) assert discr_criteria is not None - subquery.conditions.append(discr_criteria) + sqlquery.conditions.append(discr_criteria) tableref.joined = True return tableref.alias, entity._pk_columns_ class JoinedTableRef(object): - def __init__(tableref, subquery, name_path, parent_tableref, attr): - tableref.subquery = subquery + def __init__(tableref, sqlquery, name_path, parent_tableref, attr): + tableref.sqlquery = sqlquery tableref.name_path = name_path tableref.var_name = name_path if is_ident(name_path) else None tableref.alias = None @@ -984,7 +984,7 @@ def make_join(tableref, pk_only=False): if tableref.joined: if pk_only or not tableref.optimized: return tableref.alias, tableref.pk_columns - subquery = tableref.subquery + sqlquery = tableref.sqlquery attr = tableref.attr parent_pk_only = attr.pk_offset is 
not None or attr.is_collection parent_alias, left_pk_columns = tableref.parent_tableref.make_join(parent_pk_only) @@ -996,7 +996,7 @@ def make_join(tableref, pk_only=False): assert reverse.columns and not reverse.is_collection rentity = reverse.entity pk_columns = rentity._pk_columns_ - alias = subquery.make_alias(tableref.var_name or rentity.__name__) + alias = sqlquery.make_alias(tableref.var_name or rentity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, reverse.columns) else: if attr.pk_offset is not None: @@ -1009,19 +1009,19 @@ def make_join(tableref, pk_only=False): tableref.optimized = True tableref.joined = True return parent_alias, left_columns - alias = subquery.make_alias(tableref.var_name or entity.__name__) + alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_columns, pk_columns) elif not attr.reverse.is_collection: - alias = subquery.make_alias(tableref.var_name or entity.__name__) + alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, attr.reverse.columns) else: right_m2m_columns = attr.reverse_columns if attr.symmetric else attr.columns if not tableref.joined: m2m_table = attr.table - m2m_alias = subquery.make_alias('t') + m2m_alias = sqlquery.make_alias('t') reverse_columns = attr.columns if attr.symmetric else attr.reverse.columns m2m_join_cond = join_tables(parent_alias, m2m_alias, left_pk_columns, reverse_columns) - subquery.join_table(parent_alias, m2m_alias, m2m_table, m2m_join_cond) + sqlquery.join_table(parent_alias, m2m_alias, m2m_table, m2m_join_cond) if pk_only: tableref.alias = m2m_alias tableref.pk_columns = right_m2m_columns @@ -1031,13 +1031,13 @@ def make_join(tableref, pk_only=False): elif tableref.optimized: assert not pk_only m2m_alias = tableref.alias - alias = subquery.make_alias(tableref.var_name or entity.__name__) + alias = sqlquery.make_alias(tableref.var_name or 
entity.__name__) join_cond = join_tables(m2m_alias, alias, right_m2m_columns, pk_columns) if not pk_only and entity._discriminator_attr_: discr_criteria = entity._construct_discriminator_criteria_(alias) assert discr_criteria is not None join_cond.append(discr_criteria) - subquery.join_table(parent_alias, alias, entity._table_, join_cond) + sqlquery.join_table(parent_alias, alias, entity._table_, join_cond) tableref.alias = alias tableref.pk_columns = pk_columns tableref.optimized = False @@ -1201,7 +1201,7 @@ def contains(monad, item, not_in=False): sql = [ op, expr, monad.getsql() ] return BoolExprMonad(translator, sql, nullable=item.nullable) def nonzero(monad): return monad - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): provider = monad.translator.database.provider rawtype = monad.rawtype result = [] @@ -1356,7 +1356,7 @@ def contains(monad, x, not_in=False): else: sql = sqlor([ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) return BoolExprMonad(translator, sql, nullable=x.nullable or any(item.nullable for item in monad.items)) - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'ROW' ] + [ item.getsql()[0] for item in monad.items ] ] class BufferMixin(MonadMixin): @@ -1739,7 +1739,7 @@ class ObjectIterMonad(ObjectMixin, Monad): def __init__(monad, translator, tableref, entity): Monad.__init__(monad, translator, entity) monad.tableref = tableref - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): entity = monad.type alias, pk_columns = monad.tableref.make_join(pk_only=True) return [ [ 'COLUMN', alias, column ] for column in pk_columns ] @@ -1774,7 +1774,7 @@ def __init__(monad, parent, attr): monad.parent = parent monad.attr = attr monad.nullable = attr.nullable - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): parent = monad.parent attr = monad.attr entity = attr.entity @@ -1787,9 +1787,9 @@ def getsql(monad, 
subquery=None): else: columns = parent_columns elif not attr.columns: assert isinstance(monad, ObjectAttrMonad) - subquery = monad.translator.subquery - monad.translator.left_join = subquery.left_join = True - subquery.from_ast[0] = 'LEFT_JOIN' + sqlquery = monad.translator.sqlquery + monad.translator.left_join = sqlquery.left_join = True + sqlquery.from_ast[0] = 'LEFT_JOIN' alias, columns = monad.tableref.make_join() else: columns = attr.columns return [ [ 'COLUMN', alias, column ] for column in columns ] @@ -1801,10 +1801,10 @@ def __init__(monad, parent, attr): parent_monad = monad.parent entity = monad.type name_path = '-'.join((parent_monad.tableref.name_path, attr.name)) - monad.tableref = translator.subquery.get_tableref(name_path) + monad.tableref = translator.sqlquery.get_tableref(name_path) if monad.tableref is None: - parent_subquery = parent_monad.tableref.subquery - monad.tableref = parent_subquery.add_tableref(name_path, parent_monad.tableref, attr) + parent_sqlquery = parent_monad.tableref.sqlquery + monad.tableref = parent_sqlquery.add_tableref(name_path, parent_monad.tableref, attr) class StringAttrMonad(StringMixin, AttrMonad): pass class NumericAttrMonad(NumericMixin, AttrMonad): pass @@ -1845,7 +1845,7 @@ def __init__(monad, translator, type, paramkey): provider = translator.database.provider monad.converter = provider.get_converter_by_py_type(type) else: monad.converter = None - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'PARAM', monad.paramkey, monad.converter ] ] class ObjectParamMonad(ObjectMixin, ParamMonad): @@ -1855,7 +1855,7 @@ def __init__(monad, translator, entity, paramkey): varkey, i, j = paramkey assert j is None monad.params = tuple((varkey, i, j) for j in xrange(len(entity._pk_converters_))) - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): entity = monad.type assert len(monad.params) == len(entity._pk_converters_) return [ [ 'PARAM', param, converter ] for param, 
converter in izip(monad.params, entity._pk_converters_) ] @@ -1872,7 +1872,7 @@ class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass class JsonParamMonad(JsonMixin, ParamMonad): - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'JSON_PARAM', ParamMonad.getsql(monad)[0] ] ] class ExprMonad(Monad): @@ -1894,11 +1894,11 @@ def __new__(cls, *args, **kwargs): def __init__(monad, translator, type, sql, nullable=True): Monad.__init__(monad, translator, type, nullable=nullable) monad.sql = sql - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ monad.sql ] class ObjectExprMonad(ObjectMixin, ExprMonad): - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return monad.sql class StringExprMonad(StringMixin, ExprMonad): pass @@ -1974,7 +1974,7 @@ def __init__(monad, translator, value): value_type, value = normalize(value) Monad.__init__(monad, translator, value_type, nullable=value_type is NoneType) monad.value = value - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'VALUE', monad.value ] ] class NoneMonad(ConstMonad): @@ -2009,7 +2009,7 @@ class BoolExprMonad(BoolMonad): def __init__(monad, translator, sql, nullable=True): BoolMonad.__init__(monad, translator, nullable=nullable) monad.sql = sql - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ monad.sql ] def negate(monad): translator = monad.translator @@ -2058,7 +2058,7 @@ def __init__(monad, op, left, right): monad.right = right def negate(monad): return CmpMonad(cmp_negate[monad.op], monad.left, monad.right) - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): op = monad.op left_sql = monad.left.getsql() if op == 'is': @@ -2102,7 +2102,7 @@ def __init__(monad, operands): nullable = any(item.nullable for item in items) BoolMonad.__init__(monad, items[0].translator, nullable=nullable) monad.operands = items - 
def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): result = [ monad.binop ] for operand in monad.operands: operand_sql = operand.getsql() @@ -2123,7 +2123,7 @@ def __init__(monad, operand): monad.operand = operand def negate(monad): return monad.operand - def getsql(monad, subquery=None): + def getsql(monad, sqlquery=None): return [ [ 'NOT', monad.operand.getsql()[0] ] ] class ErrorSpecialFuncMonad(Monad): @@ -2460,7 +2460,7 @@ def __init__(monad, parent, attr): Monad.__init__(monad, translator, SetType(item_type)) monad.parent = parent monad.attr = attr - monad.subquery = None + monad.sqlquery = None monad.tableref = None def cmp(monad, op, monad2): translator = monad.translator @@ -2473,10 +2473,10 @@ def contains(monad, item, not_in=False): check_comparable(item, monad, 'in') if not translator.hint_join: sqlop = 'NOT_IN' if not_in else 'IN' - subquery = monad._subselect() - expr_list = subquery.expr_list - from_ast = subquery.from_ast - conditions = subquery.outer_conditions + subquery.conditions + sqlquery = monad._subselect() + expr_list = sqlquery.expr_list + from_ast = sqlquery.from_ast + conditions = sqlquery.outer_conditions + sqlquery.conditions if len(expr_list) == 1: subquery_ast = [ 'SELECT', [ 'ALL' ] + expr_list, from_ast, [ 'WHERE' ] + conditions ] sql_ast = [ sqlop, item.getsql()[0], subquery_ast ] @@ -2491,22 +2491,22 @@ def contains(monad, item, not_in=False): return result elif not not_in: translator.distinct = True - tableref = monad.make_tableref(translator.subquery) + tableref = monad.make_tableref(translator.sqlquery) expr_list = monad.make_expr_list() expr_ast = sqland([ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(expr_list, item.getsql()) ]) return BoolExprMonad(translator, expr_ast, nullable=False) else: - subquery = Subquery(translator.subquery) - tableref = monad.make_tableref(subquery) + sqlquery = SqlQuery(translator.sqlquery) + tableref = monad.make_tableref(sqlquery) attr = monad.attr alias, columns = 
tableref.make_join(pk_only=attr.reverse) expr_list = monad.make_expr_list() if not attr.reverse: columns = attr.columns - from_ast = translator.subquery.from_ast + from_ast = translator.sqlquery.from_ast from_ast[0] = 'LEFT_JOIN' - from_ast.extend(subquery.from_ast[1:]) + from_ast.extend(sqlquery.from_ast[1:]) conditions = [ [ 'EQ', [ 'COLUMN', alias, column ], expr ] for column, expr in izip(columns, item.getsql()) ] - conditions.extend(subquery.conditions) + conditions.extend(sqlquery.conditions) from_ast[-1][-1] = sqland([ from_ast[-1][-1] ] + conditions) expr_ast = sqland([ [ 'IS_NULL', expr ] for expr in expr_list ]) return BoolExprMonad(translator, expr_ast, nullable=False) @@ -2537,11 +2537,11 @@ def count(monad, distinct=None): translator = monad.translator distinct = distinct_from_monad(distinct, monad.requires_distinct(joined=translator.hint_join, for_count=True)) - subquery = monad._subselect() - expr_list = subquery.expr_list - from_ast = subquery.from_ast - inner_conditions = subquery.conditions - outer_conditions = subquery.outer_conditions + sqlquery = monad._subselect() + expr_list = sqlquery.expr_list + from_ast = sqlquery.from_ast + inner_conditions = sqlquery.conditions + outer_conditions = sqlquery.outer_conditions sql_ast = make_aggr = None extra_grouping = False @@ -2634,29 +2634,29 @@ def make_aggr(expr_list): else: result.nogroup = True return result def nonzero(monad): - subquery = monad._subselect() - sql_ast = [ 'EXISTS', subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] + sqlquery = monad._subselect() + sql_ast = [ 'EXISTS', sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] translator = monad.translator return BoolExprMonad(translator, sql_ast, nullable=False) def negate(monad): - subquery = monad._subselect() - sql_ast = [ 'NOT_EXISTS', subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] + sqlquery = monad._subselect() + sql_ast = [ 
'NOT_EXISTS', sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] translator = monad.translator return BoolExprMonad(translator, sql_ast, nullable=False) call_is_empty = negate - def make_tableref(monad, subquery): + def make_tableref(monad, sqlquery): parent = monad.parent attr = monad.attr translator = monad.translator if isinstance(parent, ObjectMixin): parent_tableref = parent.tableref - elif isinstance(parent, AttrSetMonad): parent_tableref = parent.make_tableref(subquery) + elif isinstance(parent, AttrSetMonad): parent_tableref = parent.make_tableref(sqlquery) else: assert False # pragma: no cover if attr.reverse: name_path = parent_tableref.name_path + '-' + attr.name - monad.tableref = subquery.get_tableref(name_path) \ - or subquery.add_tableref(name_path, parent_tableref, attr) + monad.tableref = sqlquery.get_tableref(name_path) \ + or sqlquery.add_tableref(name_path, parent_tableref, attr) else: monad.tableref = parent_tableref monad.tableref.can_affect_distinct = True return monad.tableref @@ -2672,36 +2672,36 @@ def make_expr_list(monad): return [ [ 'COLUMN', alias, column ] for column in columns ] def _aggregated_scalar_subselect(monad, make_aggr, extra_grouping=False): translator = monad.translator - subquery = monad._subselect() + sqlquery = monad._subselect() optimized = False if translator.optimize == monad.tableref.name_path: - sql_ast = make_aggr(subquery.expr_list) + sql_ast = make_aggr(sqlquery.expr_list) optimized = True if not translator.from_optimized: - from_ast = monad.subquery.from_ast[1:] - assert subquery.outer_conditions - from_ast[0] = from_ast[0] + [ sqland(subquery.outer_conditions) ] - translator.subquery.from_ast.extend(from_ast) + from_ast = monad.sqlquery.from_ast[1:] + assert sqlquery.outer_conditions + from_ast[0] = from_ast[0] + [ sqland(sqlquery.outer_conditions) ] + translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True - else: sql_ast = [ 'SELECT', [ 'AGGREGATES', 
make_aggr(subquery.expr_list) ], - subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] + else: sql_ast = [ 'SELECT', [ 'AGGREGATES', make_aggr(sqlquery.expr_list) ], + sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] if extra_grouping: # This is for Oracle only, with COUNT(COUNT(*)) - sql_ast.append([ 'GROUP_BY' ] + subquery.expr_list) + sql_ast.append([ 'GROUP_BY' ] + sqlquery.expr_list) return sql_ast, optimized def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=False): translator = monad.translator - subquery = monad._subselect() - expr_list = subquery.expr_list - from_ast = subquery.from_ast - inner_conditions = subquery.conditions - outer_conditions = subquery.outer_conditions + sqlquery = monad._subselect() + expr_list = sqlquery.expr_list + from_ast = sqlquery.from_ast + inner_conditions = sqlquery.conditions + outer_conditions = sqlquery.outer_conditions groupby_columns = [ inner_column[:] for cond, outer_column, inner_column in outer_conditions ] assert len({alias for _, alias, column in groupby_columns}) == 1 if extra_grouping: - inner_alias = translator.subquery.make_alias('t') + inner_alias = translator.sqlquery.make_alias('t') inner_columns = [ 'DISTINCT' ] col_mapping = {} col_names = set() @@ -2714,7 +2714,7 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F expr = [ 'AS', column_ast, cname ] new_name = cname else: - new_name = 'expr-%d' % next(translator.subquery.expr_counter) + new_name = 'expr-%d' % next(translator.sqlquery.expr_counter) col_mapping[tname, cname] = new_name expr = [ 'AS', column_ast, new_name ] inner_columns.append(expr) @@ -2730,42 +2730,42 @@ def _joined_subselect(monad, make_aggr, extra_grouping=False, coalesce_to_zero=F new_name = col_mapping[tname, cname] outer_conditions[i] = [ cond, outer_column, [ 'COLUMN', inner_alias, new_name ] ] - subquery_columns = [ 'ALL' ] + subselect_columns = [ 'ALL' ] 
for column_ast in groupby_columns: assert column_ast[0] == 'COLUMN' - subquery_columns.append([ 'AS', column_ast, column_ast[2] ]) - expr_name = 'expr-%d' % next(translator.subquery.expr_counter) - subquery_columns.append([ 'AS', make_aggr(expr_list), expr_name ]) - subquery_ast = [ subquery_columns, from_ast ] + subselect_columns.append([ 'AS', column_ast, column_ast[2] ]) + expr_name = 'expr-%d' % next(translator.sqlquery.expr_counter) + subselect_columns.append([ 'AS', make_aggr(expr_list), expr_name ]) + subquery_ast = [ subselect_columns, from_ast ] if inner_conditions and not extra_grouping: subquery_ast.append([ 'WHERE' ] + inner_conditions) subquery_ast.append([ 'GROUP_BY' ] + groupby_columns) - alias = translator.subquery.make_alias('t') + alias = translator.sqlquery.make_alias('t') for cond in outer_conditions: cond[2][1] = alias - translator.subquery.from_ast.append([ alias, 'SELECT', subquery_ast, sqland(outer_conditions) ]) + translator.sqlquery.from_ast.append([ alias, 'SELECT', subquery_ast, sqland(outer_conditions) ]) expr_ast = [ 'COLUMN', alias, expr_name ] if coalesce_to_zero: expr_ast = [ 'COALESCE', expr_ast, [ 'VALUE', 0 ] ] return expr_ast, False - def _subselect(monad, subquery=None, extract_outer_conditions=True): - if monad.subquery is not None: return monad.subquery + def _subselect(monad, sqlquery=None, extract_outer_conditions=True): + if monad.sqlquery is not None: return monad.sqlquery attr = monad.attr translator = monad.translator - if subquery is None: - subquery = Subquery(translator.subquery) - monad.make_tableref(subquery) - subquery.expr_list = monad.make_expr_list() + if sqlquery is None: + sqlquery = SqlQuery(translator.sqlquery) + monad.make_tableref(sqlquery) + sqlquery.expr_list = monad.make_expr_list() if not attr.reverse and not attr.is_required: - subquery.conditions.extend([ 'IS_NOT_NULL', expr ] for expr in subquery.expr_list) - if subquery is not translator.subquery and extract_outer_conditions: - outer_cond = 
subquery.from_ast[1].pop() - if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] - else: subquery.outer_conditions = [ outer_cond ] - monad.subquery = subquery - return subquery - def getsql(monad, subquery=None): - if subquery is None: subquery = monad.translator.subquery - monad.make_tableref(subquery) + sqlquery.conditions.extend([ 'IS_NOT_NULL', expr ] for expr in sqlquery.expr_list) + if sqlquery is not translator.sqlquery and extract_outer_conditions: + outer_cond = sqlquery.from_ast[1].pop() + if outer_cond[0] == 'AND': sqlquery.outer_conditions = outer_cond[1:] + else: sqlquery.outer_conditions = [ outer_cond ] + monad.sqlquery = sqlquery + return sqlquery + def getsql(monad, sqlquery=None): + if sqlquery is None: sqlquery = monad.translator.sqlquery + monad.make_tableref(sqlquery) return monad.make_expr_list() __add__ = make_attrset_binop('+', 'ADD') __sub__ = make_attrset_binop('-', 'SUB') @@ -2792,12 +2792,12 @@ def __init__(monad, op, sqlop, left, right): def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator - subquery = Subquery(translator.subquery) - expr = monad.getsql(subquery)[0] + sqlquery = SqlQuery(translator.sqlquery) + expr = monad.getsql(sqlquery)[0] translator.aggregated_subquery_paths.add(monad.tableref.name_path) - outer_cond = subquery.from_ast[1].pop() - if outer_cond[0] == 'AND': subquery.outer_conditions = outer_cond[1:] - else: subquery.outer_conditions = [ outer_cond ] + outer_cond = sqlquery.from_ast[1].pop() + if outer_cond[0] == 'AND': sqlquery.outer_conditions = outer_cond[1:] + else: sqlquery.outer_conditions = [ outer_cond ] if func_name == 'AVG': result_type = float elif func_name == 'GROUP_CONCAT': @@ -2810,26 +2810,26 @@ def aggregate(monad, func_name, distinct=None, sep=None): aggr_ast.append(['VALUE', sep]) if translator.optimize != monad.tableref.name_path: sql_ast = [ 
'SELECT', [ 'AGGREGATES', aggr_ast ], - subquery.from_ast, - [ 'WHERE' ] + subquery.outer_conditions + subquery.conditions ] + sqlquery.from_ast, + [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] result = ExprMonad.new(translator, result_type, sql_ast, nullable=func_name != 'SUM') result.nogroup = True else: if not translator.from_optimized: - from_ast = subquery.from_ast[1:] - assert subquery.outer_conditions - from_ast[0] = from_ast[0] + [ sqland(subquery.outer_conditions) ] - translator.subquery.from_ast.extend(from_ast) + from_ast = sqlquery.from_ast[1:] + assert sqlquery.outer_conditions + from_ast[0] = from_ast[0] + [ sqland(sqlquery.outer_conditions) ] + translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True sql_ast = aggr_ast result = ExprMonad.new(translator, result_type, sql_ast, nullable=func_name != 'SUM') result.aggregated = True return result - def getsql(monad, subquery=None): - if subquery is None: subquery = monad.translator.subquery + def getsql(monad, sqlquery=None): + if sqlquery is None: sqlquery = monad.translator.sqlquery left, right = monad.left, monad.right - left_expr = left.getsql(subquery)[0] - right_expr = right.getsql(subquery)[0] + left_expr = left.getsql(sqlquery)[0] + right_expr = right.getsql(sqlquery)[0] if isinstance(left, NumericMixin): left_path = '' else: left_path = left.tableref.name_path + '-' if isinstance(right, NumericMixin): right_path = '' @@ -2865,17 +2865,17 @@ def contains(monad, item, not_in=False): else: item_columns = item.getsql() sub = monad.subtranslator - if translator.hint_join and len(sub.subquery.from_ast[1]) == 3: - subquery_ast = sub.shallow_copy_of_subquery_ast() + if translator.hint_join and len(sub.sqlquery.from_ast[1]) == 3: + subquery_ast = sub.construct_subquery_ast() select_ast, from_ast, where_ast = subquery_ast[1:4] - subquery = translator.subquery + sqlquery = translator.sqlquery if not not_in: translator.distinct = True - if subquery.from_ast[0] == 
'FROM': - subquery.from_ast[0] = 'INNER_JOIN' + if sqlquery.from_ast[0] == 'FROM': + sqlquery.from_ast[0] = 'INNER_JOIN' else: - subquery.left_join = True - subquery.from_ast[0] = 'LEFT_JOIN' + sqlquery.left_join = True + sqlquery.from_ast[0] = 'LEFT_JOIN' col_names = set() new_names = [] exprs = [] @@ -2889,26 +2889,26 @@ def contains(monad, item, not_in=False): new_names.append(col_name) select_ast[i] = [ 'AS', column_ast, col_name ] continue - new_name = 'expr-%d' % next(subquery.expr_counter) + new_name = 'expr-%d' % next(sqlquery.expr_counter) new_names.append(new_name) select_ast[i] = [ 'AS', column_ast, new_name ] - alias = subquery.make_alias('t') + alias = sqlquery.make_alias('t') outer_conditions = [ [ 'EQ', item_column, [ 'COLUMN', alias, new_name ] ] for item_column, new_name in izip(item_columns, new_names) ] - subquery.from_ast.append([ alias, 'SELECT', subquery_ast[1:], sqland(outer_conditions) ]) + sqlquery.from_ast.append([ alias, 'SELECT', subquery_ast[1:], sqland(outer_conditions) ]) if not_in: sql_ast = sqland([ [ 'IS_NULL', [ 'COLUMN', alias, new_name ] ] for new_name in new_names ]) else: sql_ast = [ 'EQ', [ 'VALUE', 1 ], [ 'VALUE', 1 ] ] else: if len(item_columns) == 1: - subquery_ast = sub.shallow_copy_of_subquery_ast(is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', item_columns[0], subquery_ast ] elif translator.row_value_syntax: - subquery_ast = sub.shallow_copy_of_subquery_ast(is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', [ 'ROW' ] + item_columns, subquery_ast ] else: - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast() select_ast, from_ast, where_ast = subquery_ast[1:4] in_conditions = [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item_columns, select_ast[1:]) ] if not sub.aggregated: where_ast += 
in_conditions @@ -2918,7 +2918,7 @@ def contains(monad, item, not_in=False): sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] return BoolExprMonad(translator, sql_ast, nullable=False) def nonzero(monad): - subquery_ast = monad.subtranslator.shallow_copy_of_subquery_ast() + subquery_ast = monad.subtranslator.construct_subquery_ast() subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] translator = monad.translator return BoolExprMonad(translator, subquery_ast, nullable=False) @@ -2933,7 +2933,7 @@ def count(monad, distinct=None): sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast() from_ast, where_ast = subquery_ast[2:4] sql_ast = None @@ -2952,13 +2952,13 @@ def count(monad, distinct=None): if translator.sqlite_version < (3, 6, 21): if sub.aggregated: throw(TranslationError) alias, pk_columns = sub.tableref.make_join(pk_only=False) - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast() from_ast, where_ast = subquery_ast[2:4] sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, [ 'COLUMN', alias, 'ROWID' ] ] ], from_ast, where_ast ] else: - alias = translator.subquery.make_alias('t') + alias = translator.sqlquery.make_alias('t') sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', None ] ], [ 'FROM', [ alias, 'SELECT', [ [ 'DISTINCT' if distinct is not False else 'ALL' ] + sub.expr_columns, from_ast, where_ast ] ] ] ] @@ -2975,7 +2975,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): translator = monad.translator sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') - subquery_ast = sub.shallow_copy_of_subquery_ast() + subquery_ast = sub.construct_subquery_ast() from_ast, where_ast = subquery_ast[2:4] expr_type = sub.expr_type if func_name in ('SUM', 'AVG'): @@ 
-3022,14 +3022,14 @@ def call_group_concat(monad, sep=None, distinct=None): def getsql(monad): throw(NotImplementedError) -def find_or_create_having_ast(subquery_ast): +def find_or_create_having_ast(sections): groupby_offset = None - for i, section in enumerate(subquery_ast): + for i, section in enumerate(sections): section_name = section[0] if section_name == 'GROUP_BY': groupby_offset = i elif section_name == 'HAVING': return section having_ast = [ 'HAVING' ] - subquery_ast.insert(groupby_offset + 1, having_ast) + sections.insert(groupby_offset + 1, having_ast) return having_ast diff --git a/pony/orm/tests/test_declarative_join_optimization.py b/pony/orm/tests/test_declarative_join_optimization.py index 3427e060c..8bc380a77 100644 --- a/pony/orm/tests/test_declarative_join_optimization.py +++ b/pony/orm/tests/test_declarative_join_optimization.py @@ -61,14 +61,14 @@ def test4(self): self.assertEqual(Group._table_ not in flatten(q._translator.conditions), True) def test5(self): q = select(s for s in Student if s.group.number == 1 or s.group.major == '1') - self.assertEqual(Group._table_ in flatten(q._translator.subquery.from_ast), True) + self.assertEqual(Group._table_ in flatten(q._translator.sqlquery.from_ast), True) # def test6(self): ### Broken with ExprEvalError: Group[101] raises ObjectNotFound: Group[101] # q = select(s for s in Student if s.group == Group[101]) - # self.assertEqual(Group._table_ not in flatten(q._translator.subquery.from_ast), True) + # self.assertEqual(Group._table_ not in flatten(q._translator.sqlquery.from_ast), True) def test7(self): q = select(s for s in Student if sum(c.credits for c in Course if s.group.dept == c.dept) > 10) objects = q[:] - self.assertEqual(str(q._translator.subquery.from_ast), + self.assertEqual(str(q._translator.sqlquery.from_ast), "['FROM', ['s', 'TABLE', 'Student'], ['group', 'TABLE', 'Group', ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group', 'number']]]]") From b81ac6041d3807b0784b4a8d8682dee04a97694c Mon 
Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 21 Jul 2018 00:18:19 +0300 Subject: [PATCH 332/547] Change construct_subquery_ast() options: add `distinct=None`, `aliases=None` and `star=None` and remove `move_outer_conditions=True` --- pony/orm/sqlbuilding.py | 2 ++ pony/orm/sqltranslation.py | 37 ++++++++++++++++++++++++------------- 2 files changed, 26 insertions(+), 13 deletions(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 7e9a44089..b02fd7dc9 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -375,6 +375,8 @@ def make_param(builder, param_class, paramkey, *args): return param def make_composite_param(builder, paramkey, items, func): return builder.make_param(builder.composite_param_class, paramkey, items, func) + def STAR(builder, table_alias): + return builder.quote_name(table_alias), '.*' def ROW(builder, *items): return '(', join(', ', imap(builder, items)), ')' def VALUE(builder, value): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index c1a2c2071..b4666f821 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -397,16 +397,27 @@ def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False return next(iter(translator.aggregated_subquery_paths)) - def construct_subquery_ast(translator, move_outer_conditions=True, is_not_null_checks=False): - subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=False, is_not_null_checks=is_not_null_checks) + def construct_subquery_ast(translator, aliases=None, star=None, distinct=None, is_not_null_checks=False): + subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=distinct, is_not_null_checks=is_not_null_checks) assert attr_offsets is None assert len(subquery_ast) >= 3 and subquery_ast[0] == 'SELECT' select_ast = subquery_ast[1][:] - assert select_ast[0] == 'ALL' + assert select_ast[0] in ('ALL', 'DISTINCT', 
'AGGREGATES'), select_ast + if aliases: + assert not star and len(aliases) == len(select_ast) - 1 + for i, alias in enumerate(aliases): + expr = select_ast[i+1] + if expr[0] == 'AS': expr = expr[1] + select_ast[i+1] = [ 'AS', expr, alias ] + elif star is not None: + assert isinstance(star, basestring) + for section in subquery_ast: + assert section[0] not in ('GROUP_BY', 'HAVING'), subquery_ast + select_ast[1:] = [ [ 'STAR', star ] ] from_ast = subquery_ast[2][:] - assert from_ast[0] == 'FROM' + assert from_ast[0] in ('FROM', 'LEFT_JOIN') if len(subquery_ast) == 3: where_ast = [ 'WHERE' ] @@ -418,7 +429,7 @@ def construct_subquery_ast(translator, move_outer_conditions=True, is_not_null_c where_ast = subquery_ast[3][:] other_ast = subquery_ast[4:] - if move_outer_conditions and len(from_ast[1]) == 4: + if len(from_ast[1]) == 4: outer_conditions = from_ast[1][-1] from_ast[1] = from_ast[1][:-1] if outer_conditions[0] == 'AND': where_ast[1:1] = outer_conditions[1:] @@ -2866,7 +2877,7 @@ def contains(monad, item, not_in=False): sub = monad.subtranslator if translator.hint_join and len(sub.sqlquery.from_ast[1]) == 3: - subquery_ast = sub.construct_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) select_ast, from_ast, where_ast = subquery_ast[1:4] sqlquery = translator.sqlquery if not not_in: @@ -2902,13 +2913,13 @@ def contains(monad, item, not_in=False): else: sql_ast = [ 'EQ', [ 'VALUE', 1 ], [ 'VALUE', 1 ] ] else: if len(item_columns) == 1: - subquery_ast = sub.construct_subquery_ast(is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', item_columns[0], subquery_ast ] elif translator.row_value_syntax: - subquery_ast = sub.construct_subquery_ast(is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', [ 'ROW' ] + item_columns, subquery_ast ] 
else: - subquery_ast = sub.construct_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) select_ast, from_ast, where_ast = subquery_ast[1:4] in_conditions = [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item_columns, select_ast[1:]) ] if not sub.aggregated: where_ast += in_conditions @@ -2918,7 +2929,7 @@ def contains(monad, item, not_in=False): sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] return BoolExprMonad(translator, sql_ast, nullable=False) def nonzero(monad): - subquery_ast = monad.subtranslator.construct_subquery_ast() + subquery_ast = monad.subtranslator.construct_subquery_ast(distinct=False) subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] translator = monad.translator return BoolExprMonad(translator, subquery_ast, nullable=False) @@ -2933,7 +2944,7 @@ def count(monad, distinct=None): sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') - subquery_ast = sub.construct_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] sql_ast = None @@ -2952,7 +2963,7 @@ def count(monad, distinct=None): if translator.sqlite_version < (3, 6, 21): if sub.aggregated: throw(TranslationError) alias, pk_columns = sub.tableref.make_join(pk_only=False) - subquery_ast = sub.construct_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] sql_ast = [ 'SELECT', [ 'AGGREGATES', [ 'COUNT', True if distinct is None else distinct, [ 'COLUMN', alias, 'ROWID' ] ] ], @@ -2975,7 +2986,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): translator = monad.translator sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') - subquery_ast = sub.construct_subquery_ast() + subquery_ast = sub.construct_subquery_ast(distinct=False) from_ast, where_ast = subquery_ast[2:4] expr_type = sub.expr_type if func_name in 
('SUM', 'AVG'): From c28c0c882d06cc643cbfcb26d97c6c16138743bd Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 20 Jul 2018 19:38:08 +0300 Subject: [PATCH 333/547] Fix query optimization check --- pony/orm/sqltranslation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index b4666f821..b65ae6a79 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -396,7 +396,11 @@ def func(value, converter=converter): def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False - return next(iter(translator.aggregated_subquery_paths)) + aggr_path = next(iter(translator.aggregated_subquery_paths)) + for name in translator.sqlquery.tablerefs: + if not aggr_path.startswith(name): + return False + return aggr_path def construct_subquery_ast(translator, aliases=None, star=None, distinct=None, is_not_null_checks=False): subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=distinct, is_not_null_checks=is_not_null_checks) assert attr_offsets is None From 60bc463aac03e1893d0242a6ac599f8f1b7c683a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 21 Jul 2018 03:18:40 +0300 Subject: [PATCH 334/547] Support of select(x for x in previous_query) --- pony/orm/core.py | 261 +++++++++--- pony/orm/dbapiprovider.py | 9 +- pony/orm/dbproviders/oracle.py | 1 + pony/orm/ormtypes.py | 11 + pony/orm/sqlbuilding.py | 5 +- pony/orm/sqltranslation.py | 395 +++++++++++++----- .../tests/test_declarative_query_set_monad.py | 65 ++- pony/orm/tests/test_query.py | 15 +- .../tests/test_select_from_select_queries.py | 225 ++++++++++ 9 files changed, 811 insertions(+), 176 deletions(-) create mode 100644 pony/orm/tests/test_select_from_select_queries.py diff --git a/pony/orm/core.py b/pony/orm/core.py index e70971f56..bd6dad0bd 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -19,7 +19,7 @@ 
import pony from pony import options from pony.orm.decompiling import decompile -from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, normalize, Json, TrackedValue +from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, normalize, Json, TrackedValue, QueryType from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, @@ -205,9 +205,17 @@ def __init__(exc, src, cause): TranslationError.__init__(exc, msg) exc.cause = cause -class OptimizationFailed(Exception): +class PonyInternalException(Exception): + pass + +class OptimizationFailed(PonyInternalException): pass # Internal exception, cannot be encountered in user code +class UseAnotherTranslator(PonyInternalException): + def __init__(self, translator): + Exception.__init__(self, 'This exception should be catched internally by PonyORM') + self.translator = translator + class DatabaseContainsIncorrectValue(RuntimeWarning): pass @@ -5156,18 +5164,19 @@ def get_globals_and_locals(args, kwargs, frame_depth, from_generator=False): % (len(args) > 4 and 's' or '', ', '.join(imap(repr, args[3:])))) else: locals = {} - locals.update(sys._getframe(frame_depth+1).f_locals) + if frame_depth is not None: + locals.update(sys._getframe(frame_depth+1).f_locals) if type(func) is types.GeneratorType: globals = func.gi_frame.f_globals locals.update(func.gi_frame.f_locals) - else: + elif frame_depth is not None: globals = sys._getframe(frame_depth+1).f_globals if kwargs: throw(TypeError, 'Keyword arguments cannot be specified together with positional arguments') return func, globals, locals def make_query(args, frame_depth, left_join=False): gen, globals, locals = get_globals_and_locals( - args, kwargs=None, frame_depth=frame_depth+1, from_generator=True) + args, kwargs=None, frame_depth=frame_depth+1 if frame_depth is not None else None, 
from_generator=True) if isinstance(gen, types.GeneratorType): tree, external_names, cells = decompile(gen) code_key = id(gen.gi_frame.f_code) @@ -5256,6 +5265,22 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): key = filter_num, src try: value = extractor(globals, locals) except Exception as cause: raise ExprEvalError(src, cause) + + if isinstance(value, types.GeneratorType): + value = make_query((value,), frame_depth=None) + + if isinstance(value, QueryResultIterator): + query_result = value._query_result + if query_result._items: + value = tuple(query_result._items[value._position:]) + else: + value = value._query_result._query + + if isinstance(value, Query): + query = value + vars.update(query._vars) + vartypes.update(query._translator.vartypes) + if src == 'None' and value is not None: throw(TranslationError) if src == 'True' and value is not True: throw(TranslationError) if src == 'False' and value is not False: throw(TranslationError) @@ -5268,7 +5293,8 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): else: unsupported = True if unsupported: typename = type(value).__name__ - if src == '.0': throw(TypeError, 'Cannot iterate over non-entity object') + if src == '.0': + throw(TypeError, 'Query cannot iterate over anything but entity class or another query') throw(TypeError, 'Expression `%s` has unsupported type %r' % (src, typename)) vartypes[key], value = normalize(value) vars[key] = value @@ -5288,14 +5314,23 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False node = tree.quals[0].iter origin = vars[filter_num, node.src] - if isinstance(origin, EntityIter): origin = origin.entity - elif not isinstance(origin, EntityMeta): - if node.src == '.0': throw(TypeError, 'Cannot iterate over non-entity object') - throw(TypeError, 'Cannot iterate over non-entity object %s' % node.src) - query._origin = origin - database = origin._database_ - if database is None: throw(TranslationError, 
'Entity %s is not mapped to a database' % origin.__name__) - if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) + if isinstance(origin, Query): + database = origin._translator.database + elif isinstance(origin, QueryResult): + database = origin._query._translator.database + elif isinstance(origin, QueryResultIterator): + database = origin._query_result._query._translator.database + else: + if isinstance(origin, EntityIter): + origin = origin.entity + elif not isinstance(origin, EntityMeta): + if node.src == '.0': throw(TypeError, + 'Query can only iterate over entity or another query (not a list of objects)') + throw(TypeError, 'Cannot iterate over non-entity object %s' % node.src) + database = origin._database_ + if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) + if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) + database.provider.normalize_vars(vars, vartypes) query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) @@ -5308,13 +5343,20 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False pickled_tree = pickle_ast(tree) tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) translator_cls = database.provider.translator_cls - translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) + try: + translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) + except UseAnotherTranslator as e: + translator = e.translator name_path = translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) - try: translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), - left_join=True, optimize=name_path) - except OptimizationFailed: 
translator.optimization_failed = True + try: + translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), + left_join=True, optimize=name_path) + except UseAnotherTranslator as e: + translator = e.translator + except OptimizationFailed: + translator.optimization_failed = True translator.pickled_tree = pickled_tree if translator.can_be_cached: database._translator_cache[query._key] = translator @@ -5327,6 +5369,10 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False query._prefetch = False query._entities_to_prefetch = set() query._attrs_to_prefetch_dict = defaultdict(set) + def _get_type_(query): + return QueryType(query) + def _normalize_var(query, query_type): + return query_type, query def _clone(query, **kwargs): new_query = object.__new__(Query) new_query.__dict__.update(query.__dict__) @@ -5398,38 +5444,37 @@ def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, agg def get_sql(query): sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments() return sql - def _fetch(query, limit=None, offset=None): + def _actual_fetch(query, limit=None, offset=None): translator = query._translator sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(limit, offset) database = query._database cache = database._get_cache() if query._for_update: cache.immediate = True cache.prepare_connection_for_query_execution() # may clear cache.query_results - try: result = cache.query_results[query_key] + try: items = cache.query_results[query_key] except KeyError: cursor = database._exec_sql(sql, arguments) if isinstance(translator.expr_type, EntityMeta): entity = translator.expr_type - result = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, + items = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, used_attrs=translator.get_used_attrs()) elif len(translator.row_layout) == 1: func, slice_or_offset, src = 
translator.row_layout[0] - result = list(starmap(func, cursor.fetchall())) + items = list(starmap(func, cursor.fetchall())) else: - result = [ tuple(func(sql_row[slice_or_offset]) + items = [ tuple(func(sql_row[slice_or_offset]) for func, slice_or_offset, src in translator.row_layout) for sql_row in cursor.fetchall() ] for i, t in enumerate(translator.expr_type): - if isinstance(t, EntityMeta) and t._subclasses_: t._load_many_(row[i] for row in result) - if query_key is not None: cache.query_results[query_key] = result + if isinstance(t, EntityMeta) and t._subclasses_: t._load_many_(row[i] for row in items) + if query_key is not None: cache.query_results[query_key] = items else: stats = database._dblocal.stats stat = stats.get(sql) if stat is not None: stat.cache_count += 1 else: stats[sql] = QueryStat(sql) - - if query._prefetch: query._do_prefetch(result) - return QueryResult(result, query, translator.expr_type, translator.col_names) + if query._prefetch: query._do_prefetch(items) + return items @cut_traceback def prefetch(query, *args): query = query._clone(_entities_to_prefetch=query._entities_to_prefetch.copy(), @@ -5547,7 +5592,7 @@ def delete(query, bulk=None): if not isinstance(query._translator.expr_type, EntityMeta): throw(TypeError, 'Delete query should be applied to a single entity. 
Got: %s' % ast2src(query._translator.tree.expr)) - objects = query._fetch() + objects = query._actual_fetch() for obj in objects: obj._delete_() return len(objects) translator = query._translator @@ -5567,10 +5612,10 @@ def delete(query, bulk=None): return cursor.rowcount @cut_traceback def __len__(query): - return len(query._fetch()) + return len(query._actual_fetch()) @cut_traceback def __iter__(query): - return iter(query._fetch()) + return iter(query._fetch(lazy=True)) @cut_traceback def order_by(query, *args): return query._order_by('order_by', *args) @@ -5638,18 +5683,20 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names if argnames: if original_names: for name in argnames: - if name not in prev_translator.sqlquery.tablerefs: throw(TypeError, - 'Lambda argument %s does not correspond to any loop variable in original query' % name) + if name not in prev_translator.namespace: throw(TypeError, + 'Lambda argument `%s` does not correspond to any variable in original query' % name) else: expr_type = prev_translator.expr_type expr_count = len(expr_type) if type(expr_type) is tuple else 1 if len(argnames) != expr_count: throw(TypeError, 'Incorrect number of lambda arguments. 
' 'Expected: %d, got: %d' % (expr_count, len(argnames))) + else: + original_names = True filter_num = next(filter_num_counter) func_ast, extractors = create_extractors( - func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.sqlquery) + func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.namespace) if extractors: vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) @@ -5669,10 +5716,13 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names if name_path: tree_copy = unpickle_ast(prev_translator.pickled_tree) # tree = deepcopy(tree) translator_cls = prev_translator.__class__ - new_translator = translator_cls( + try: + new_translator = translator_cls( tree_copy, None, prev_translator.original_filter_num, prev_translator.extractors, None, prev_translator.vartypes.copy(), left_join=True, optimize=name_path) + except UseAnotherTranslator: + assert False new_translator = query._reapply_filters(new_translator) new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) query._database._translator_cache[new_key] = new_translator @@ -5762,19 +5812,19 @@ def __getitem__(query, key): if not start: return query._fetch() else: throw(TypeError, "Parameter 'stop' of slice object should be specified") if start >= stop: return [] - limit = stop - start - offset = start + return query._fetch(limit=stop-start, offset=start) + def _fetch(query, limit=None, offset=None, lazy=False): + return QueryResult(query, limit, offset, lazy=lazy) + @cut_traceback + def fetch(query, limit=None, offset=None): return query._fetch(limit, offset) @cut_traceback def limit(query, limit, offset=None): - start = offset or 0 - stop = start + limit - return query[start:stop] + return query._fetch(limit, offset, lazy=True) 
@cut_traceback def page(query, pagenum, pagesize=10): - start = (pagenum - 1) * pagesize - stop = pagenum * pagesize - return query[start:stop] + offset = (pagenum - 1) * pagesize + return query._fetch(pagesize, offset, lazy=True) def _aggregate(query, aggr_func_name, distinct=None, sep=None): translator = query._translator sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments( @@ -5839,25 +5889,108 @@ def strcut(s, width): else: return s[:width-3] + '...' -class QueryResult(list): - __slots__ = '_query', '_expr_type', '_col_names' - def __init__(result, list, query, expr_type, col_names): - result[:] = list - result._query = query - result._expr_type = expr_type - result._col_names = col_names - def __getstate__(result): - return list(result), result._expr_type, result._col_names - def __setstate__(result, state): - result[:] = state[0] - result._expr_type = state[1] - result._col_names = state[2] + +class QueryResultIterator(object): + __slots__ = '_query_result', '_position' + def __init__(self, query_result): + self._query_result = query_result + self._position = 0 + def _get_type_(self): + if self._position != 0: + throw(NotImplementedError, 'Cannot use partially exhausted iterator, please convert to list') + return self._query_result._get_type_() + def _normalize_var(self, query_type): + if self._position != 0: throw(NotImplementedError) + return self._query_result._normalize_var(query_type) + def next(self): + qr = self._query_result + if qr._items is None: + qr._items = qr._query._actual_fetch(qr._limit, qr._offset) + if self._position >= len(qr._items): + raise StopIteration + item = qr._items[self._position] + self._position += 1 + return item + __next__ = next + def __length_hint__(self): + return len(self._query_result) - self._position + + +class QueryResult(object): + __slots__ = '_query', '_limit', '_offset', '_items', '_expr_type', '_col_names' + def __init__(self, query, limit, offset, lazy): + translator = query._translator 
+ self._query = query + self._limit = limit + self._offset = offset + self._items = None if lazy else self._query._actual_fetch(limit, offset) + self._expr_type = translator.expr_type + self._col_names = translator.col_names + def _get_type_(self): + if self._items is None: + return QueryType(self._query) + item_type = self._query._translator.expr_type + return tuple(item_type for item in self._items) + def _normalize_var(self, query_type): + if self._items is None: + return query_type, self._query + items = tuple(normalize(item) for item in self._items) + item_type = self._query._translator.expr_type + return tuple(item_type for item in items), items + def _get_items(self): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + return self._items + def __getstate__(self): + return self._get_items(), self._limit, self._offset, self._expr_type, self._col_names + def __setstate__(self, state): + self._query = None + self._items, self._limit, self._offset, self._expr_type, self._col_names = state + def __repr__(self): + return repr(self._get_items()) + def __iter__(self): + return QueryResultIterator(self) + def __len__(self): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + return len(self._items) + def __getitem__(self, key): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + return self._items[key] + def __contains__(self, item): + return item in self._get_items() + def index(self, item): + return self._get_items().index(item) + def _other_items(self, other): + return other._get_items() if isinstance(other, QueryResult) else other + def __eq__(self, other): + return self._get_items() == self._other_items(other) + def __ne__(self, other): + return self._get_items() != self._other_items(other) + def __lt__(self, other): + return self._get_items() < self._other_items(other) + def __le__(self, other): + return self._get_items() <= 
self._other_items(other) + def __gt__(self, other): + return self._get_items() > self._other_items(other) + def __ge__(self, other): + return self._get_items() >= self._other_items(other) + def __reversed__(self): + return reversed(self._get_items()) + def reverse(self): + self._get_items().reverse() + def sort(self, *args, **kwargs): + self._get_items().sort(*args, **kwargs) @cut_traceback - def show(result, width=None): + def show(self, width=None): + if self._items is None: + self._items = self._query._actual_fetch(self._limit, self._offset) + if not width: width = options.CONSOLE_WIDTH max_columns = width // 5 - expr_type = result._expr_type - col_names = result._col_names + expr_type = self._expr_type + col_names = self._col_names def to_str(x): return tostring(x).replace('\n', ' ') @@ -5870,11 +6003,11 @@ def to_str(x): col_name = col_names[0] row_maker = lambda obj: (getattr(obj, col_name),) else: row_maker = attrgetter(*col_names) - rows = [ tuple(to_str(value) for value in row_maker(obj)) for obj in result ] + rows = [tuple(to_str(value) for value in row_maker(obj)) for obj in self._items] elif len(col_names) == 1: - rows = [ (to_str(obj),) for obj in result ] + rows = [(to_str(obj),) for obj in self._items] else: - rows = [ tuple(to_str(value) for value in row) for row in result ] + rows = [tuple(to_str(value) for value in row) for row in self._items] remaining_columns = {} for col_num, colname in enumerate(col_names): @@ -5902,8 +6035,8 @@ def to_str(x): print(strjoin('+', ('-' * width_dict[i] for i in xrange(len(col_names))))) for row in rows: print(strjoin('|', (strcut(item, width_dict[i]) for i, item in enumerate(row)))) - def to_json(result, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): - return result._query._database.to_json(result, include, exclude, converter, with_schema, schema_hash) + def to_json(self, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): + return 
self._query._database.to_json(self, include, exclude, converter, with_schema, schema_hash) @cut_traceback diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 5f1b7129e..cd913230f 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, basestring, unicode, buffer, int_types +from pony.py23compat import PY2, basestring, unicode, buffer, int_types, iteritems import os, re, json from decimal import Decimal, InvalidOperation @@ -9,7 +9,7 @@ import pony from pony.utils import is_utf8, decorator, throw, localbase, deprecated from pony.converting import str2date, str2time, str2datetime, str2timedelta -from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, Json +from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, Json, QueryType class DBException(Exception): def __init__(exc, original_exc, *args): @@ -195,7 +195,10 @@ def format_table_name(provider, name): return provider.quote_name(name) def normalize_vars(provider, vars, vartypes): - pass + for key, value in iteritems(vars): + vartype = vartypes[key] + if isinstance(vartype, QueryType): + vartypes[key], vars[key] = value._normalize_var(vartype) def ast2sql(provider, ast): builder = provider.sqlbuilder_cls(provider, ast) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index d8b97d873..bb6eaddad 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -437,6 +437,7 @@ def normalize_name(provider, name): return name[:provider.max_name_len].upper() def normalize_vars(provider, vars, vartypes): + DBAPIProvider.normalize_vars(provider, vars, vartypes) for key, value in iteritems(vars): if value == '': vars[key] = None diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 0f9f012cd..48ae017f5 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -123,6 
+123,17 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) +class QueryType(object): + def __init__(self, query): + self.query_key = query._key + self.translator = query._translator + def __hash__(self): + return hash(self.query_key) + def __eq__(self, other): + return type(other) is QueryType and self.query_key == other.query_key + def __ne__(self, other): + return not self.__eq__(other) + numeric_types = {bool, int, float, Decimal} comparable_types = {int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID} primitive_types = comparable_types | {buffer} diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index b02fd7dc9..043e33f7a 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -23,11 +23,12 @@ def __init__(param, paramstyle, paramkey, converter=None, optimistic=False): def eval(param, values): varkey, i, j = param.paramkey value = values[varkey] - t = type(value) if i is not None: + t = type(value) if t is tuple: value = value[i] elif t is RawSQL: value = value.values[i] - else: assert False + elif hasattr(value, '_get_items'): value = value._get_items()[i] + else: assert False, t if j is not None: assert type(type(value)).__name__ == 'EntityMeta' value = value._get_raw_pkval_()[j] diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index b65ae6a79..b8bcdbbef 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -18,10 +18,10 @@ from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ normalize, normalize_type, coerce_types, are_comparable_types, \ - Json + Json, QueryType from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper, \ - special_functions, const_functions, extract_vars + special_functions, const_functions, extract_vars, Query, UseAnotherTranslator NoneType = type(None) @@ -166,6 +166,7 @@ def call(translator, method, 
node): return monad def __init__(translator, tree, parent_translator, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): + this = translator assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) translator.can_be_cached = True @@ -185,8 +186,7 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No translator.extractors = extractors translator.vars = vars translator.vartypes = vartypes - translator.lambda_argnames = None - translator.method_argnames_mapping_stack = [] + translator.namespace_stack = [{}] if not parent_translator else [ parent_translator.namespace.copy() ] translator.func_extractors_map = {} translator.getattr_values = {} translator.func_vartypes = {} @@ -194,10 +194,8 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No translator.optimize = optimize translator.from_optimized = False translator.optimization_failed = False - sqlquery = translator.sqlquery - tablerefs = sqlquery.tablerefs translator.distinct = False - translator.conditions = sqlquery.conditions + translator.conditions = translator.sqlquery.conditions translator.having_conditions = [] translator.order = [] translator.inside_order_by = False @@ -207,90 +205,131 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No translator.aggregated_subquery_paths = set() for i, qual in enumerate(tree.quals): assign = qual.assign - if not isinstance(assign, ast.AssName): throw(NotImplementedError, ast2src(assign)) - if assign.flags != 'OP_ASSIGN': throw(TypeError, ast2src(assign)) + if isinstance(assign, ast.AssTuple): + ass_names = tuple(assign.nodes) + elif isinstance(assign, ast.AssName): + ass_names = (assign,) + else: + throw(NotImplemented, ast2src(assign)) + + for ass_name in ass_names: + if not isinstance(ass_name, ast.AssName): + throw(NotImplemented, ast2src(ass_name)) + if ass_name.flags != 'OP_ASSIGN': + throw(TypeError, 
ast2src(ass_name)) - name = assign.name - if name in tablerefs: throw(TranslationError, 'Duplicate name: %r' % name) - if name.startswith('__'): throw(TranslationError, 'Illegal name: %r' % name) + names = tuple(ass_name.name for ass_name in ass_names) + for name in names: + if name in translator.namespace and name in translator.sqlquery.tablerefs: + throw(TranslationError, 'Duplicate name: %r' % name) + if name.startswith('__'): throw(TranslationError, 'Illegal name: %r' % name) + + name = names[0] if len(names) == 1 else None + + def check_name_is_single(): + if len(names) > 1: throw(TypeError, 'Single variable name expected. Got: %s' % ast2src(assign)) + + database = entity = None node = qual.iter monad = getattr(node, 'monad', None) - src = getattr(node, 'src', None) + if monad: # Lambda was encountered inside generator + check_name_is_single() assert parent_translator and i == 0 entity = monad.type.item_type if isinstance(monad, EntityMonad): - tablerefs[name] = TableRef(sqlquery, name, entity) + tableref = TableRef(translator.sqlquery, name, entity) + translator.sqlquery.tablerefs[name] = tableref elif isinstance(monad, AttrSetMonad): translator.sqlquery = monad._subselect(translator.sqlquery, extract_outer_conditions=False) tableref = monad.tableref - translator.method_argnames_mapping_stack.append({ - name: ObjectIterMonad(translator, tableref, entity)}) else: assert False # pragma: no cover - elif src: - iterable = translator.root_translator.vartypes[translator.filter_num, src] - if not isinstance(iterable, SetType): throw(TranslationError, - 'Inside declarative query, iterator must be entity. ' - 'Got: for %s in %s' % (name, ast2src(qual.iter))) - entity = iterable.item_type - if not isinstance(entity, EntityMeta): - throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) - if i > 0: - if translator.left_join: throw(TranslationError, - 'Collection expected inside left join query. 
' - 'Got: for %s in %s' % (name, ast2src(qual.iter))) - translator.distinct = True - tableref = TableRef(sqlquery, name, entity) - tablerefs[name] = tableref - tableref.make_join() - node.monad = ObjectIterMonad(translator, tableref, entity) - else: - attr_names = [] - while isinstance(node, ast.Getattr): - attr_names.append(node.attrname) - node = node.expr - if not isinstance(node, ast.Name) or not attr_names: - throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) - node_name = node.name - attr_names.reverse() - name_path = node_name - - monad = translator.resolve_name(node_name) - if monad is None: - throw(TranslationError, "Name %r must be defined in query" % node_name) - if not isinstance(monad, ObjectIterMonad): - throw(NotImplementedError) - parent_tableref = monad.tableref - parent_entity = parent_tableref.entity - - last_index = len(attr_names) - 1 - for j, attrname in enumerate(attr_names): - attr = parent_entity._adict_.get(attrname) - if attr is None: throw(AttributeError, attrname) - entity = attr.py_type + new_namespace = translator.namespace.copy() + new_namespace[name] = ObjectIterMonad(translator, tableref, entity) + translator.namespace_stack.append(new_namespace) + elif node.external: + iterable = translator.root_translator.vartypes[translator.filter_num, node.src] + if isinstance(iterable, SetType): + check_name_is_single() + entity = iterable.item_type if not isinstance(entity, EntityMeta): + throw(TranslationError, 'for %s in %s' % (name, ast2src(qual.iter))) + if i > 0: + if translator.left_join: throw(TranslationError, + 'Collection expected inside left join query. 
' + 'Got: for %s in %s' % (name, ast2src(qual.iter))) + translator.distinct = True + tableref = TableRef(translator.sqlquery, name, entity) + translator.sqlquery.tablerefs[name] = tableref + tableref.make_join() + translator.namespace[name] = node.monad = ObjectIterMonad(translator, tableref, entity) + elif isinstance(iterable, QueryType): + base_translator = deepcopy(iterable.translator) + database = base_translator.database + try: + translator.process_query_qual(base_translator, names, try_extend_base_query=not i) + except UseAnotherTranslator as e: + translator = e.translator + else: throw(TranslationError, 'Inside declarative query, iterator must be entity. ' + 'Got: for %s in %s' % (name, ast2src(qual.iter))) + + else: + translator.dispatch(node) + monad = node.monad + + if isinstance(monad, QuerySetMonad): + subtranslator = monad.subtranslator + database = subtranslator.database + try: + translator.process_query_qual(subtranslator, names) + except UseAnotherTranslator: + assert False + else: + check_name_is_single() + attr_names = [] + while isinstance(monad, AttrSetMonad) and monad.parent is not None: + attr_names.append(monad.attr.name) + monad = monad.parent + attr_names.reverse() + + if not isinstance(monad, ObjectIterMonad): throw(NotImplementedError, 'for %s in %s' % (name, ast2src(qual.iter))) - can_affect_distinct = None - if attr.is_collection: - if not isinstance(attr, Set): throw(NotImplementedError, ast2src(qual.iter)) - reverse = attr.reverse - if reverse.is_collection: - if not isinstance(reverse, Set): throw(NotImplementedError, ast2src(qual.iter)) - translator.distinct = True - elif parent_tableref.alias != tree.quals[i-1].assign.name: - translator.distinct = True - else: can_affect_distinct = True - if j == last_index: name_path = name - else: name_path += '-' + attr.name - tableref = JoinedTableRef(sqlquery, name_path, parent_tableref, attr) - if can_affect_distinct is not None: - tableref.can_affect_distinct = can_affect_distinct - 
tablerefs[name_path] = tableref - parent_tableref = tableref - parent_entity = entity - - database = entity._database_ + name_path = monad.tableref.alias # or name_path, it is the same + + parent_tableref = monad.tableref + parent_entity = parent_tableref.entity + + last_index = len(attr_names) - 1 + for j, attrname in enumerate(attr_names): + attr = parent_entity._adict_.get(attrname) + if attr is None: throw(AttributeError, attrname) + entity = attr.py_type + if not isinstance(entity, EntityMeta): + throw(NotImplementedError, 'for %s in %s' % (name, ast2src(qual.iter))) + can_affect_distinct = None + if attr.is_collection: + if not isinstance(attr, Set): throw(NotImplementedError, ast2src(qual.iter)) + reverse = attr.reverse + if reverse.is_collection: + if not isinstance(reverse, Set): throw(NotImplementedError, ast2src(qual.iter)) + translator.distinct = True + elif parent_tableref.alias != tree.quals[i-1].assign.name: + translator.distinct = True + else: can_affect_distinct = True + if j == last_index: name_path = name + else: name_path += '-' + attr.name + tableref = translator.sqlquery.add_tableref(name_path, parent_tableref, attr) + if j == last_index: + translator.namespace[name] = ObjectIterMonad(translator, tableref, tableref.entity) + if can_affect_distinct is not None: + tableref.can_affect_distinct = can_affect_distinct + parent_tableref = tableref + parent_entity = entity + + if database is None: + assert entity is not None + database = entity._database_ assert database.schema is not None if translator.database is None: translator.database = database elif translator.database is not database: throw(TranslationError, @@ -358,7 +397,8 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No expr_set.add(m.tableref.name_path) elif isinstance(m, AttrMonad) and isinstance(m.parent, ObjectIterMonad): expr_set.add((m.parent.tableref.name_path, m.attr)) - for tr in tablerefs.values(): + for tr in 
translator.sqlquery.tablerefs.values(): + if tr.entity is None: continue if not tr.can_affect_distinct: continue if tr.name_path in expr_set: continue if any((tr.name_path, attr) not in expr_set for attr in tr.entity._pk_attrs_): @@ -393,6 +433,11 @@ def func(value, converter=converter): translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in translator.row_layout ] translator.vars = None + if translator is not this: + raise UseAnotherTranslator(translator) + @property + def namespace(translator): + return translator.namespace_stack[-1] def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False @@ -401,9 +446,94 @@ def can_be_optimized(translator): if not aggr_path.startswith(name): return False return aggr_path + def process_query_qual(translator, other_translator, names, try_extend_base_query=False): + sqlquery = translator.sqlquery + tablerefs = sqlquery.tablerefs + expr_types = other_translator.expr_type + if not isinstance(expr_types, tuple): expr_types = (expr_types,) + expr_count = len(expr_types) + + if expr_count > 1 and len(names) == 1: + throw(NotImplementedError, + 'Please unpack a tuple of (%s) in for-loop to individual variables (like: "for x, y in ...")' + % (', '.join(ast2src(m.node) for m in other_translator.expr_monads))) + elif expr_count > len(names): + throw(TranslationError, + 'Not enough values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' + % (', '.join(names), + ', '.join(ast2src(m.node) for m in other_translator.expr_monads), + len(names), expr_count)) + elif expr_count < len(names): + throw(TranslationError, + 'Too many values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' + % (', '.join(names), + ', '.join(ast2src(m.node) for m in other_translator.expr_monads), + len(names), expr_count)) + + if try_extend_base_query: + if other_translator.aggregated: pass + elif 
other_translator.left_join: pass + else: + assert translator.parent is None + assert other_translator.vars is None + other_translator.filter_num = translator.filter_num + other_translator.extractors.update(translator.extractors) + other_translator.vars = translator.vars + other_translator.vartypes.update(translator.vartypes) + other_translator.left_join = translator.left_join + other_translator.optimize = translator.optimize + other_translator.namespace_stack = [ + {name: expr for name, expr in izip(names, other_translator.expr_monads)} + ] + raise UseAnotherTranslator(other_translator) + + if len(names) == 1 and isinstance(other_translator.expr_type, EntityMeta) \ + and not other_translator.aggregated and not other_translator.distinct: + name = names[0] + entity = other_translator.expr_type + [expr_monad] = other_translator.expr_monads + entity_alias = expr_monad.tableref.alias + subquery_ast = other_translator.construct_subquery_ast(star=entity_alias) + tableref = StarTableRef(sqlquery, name, entity, subquery_ast) + tablerefs[name] = tableref + tableref.make_join() + translator.namespace[name] = ObjectIterMonad(translator, tableref, entity) + else: + aliases = [] + aliases_dict = {} + for name, base_expr_monad in izip(names, other_translator.expr_monads): + t = base_expr_monad.type + if isinstance(t, EntityMeta): + t_aliases = [] + for suffix in t._pk_paths_: + alias = '%s-%s' % (name, suffix) + t_aliases.append(alias) + aliases.extend(t_aliases) + aliases_dict[base_expr_monad] = t_aliases + else: + aliases.append(name) + aliases_dict[base_expr_monad] = name + + subquery_ast = other_translator.construct_subquery_ast(aliases=aliases) + tableref = ExprTableRef(sqlquery, 't', subquery_ast, names, aliases) + for name in names: + tablerefs[name] = tableref + tableref.make_join() + + for name, base_expr_monad in izip(names, other_translator.expr_monads): + t = base_expr_monad.type + if isinstance(t, EntityMeta): + columns = aliases_dict[base_expr_monad] + expr_tableref 
= ExprJoinedTableRef(sqlquery, tableref, columns, name, t) + expr_monad = ObjectIterMonad(translator, expr_tableref, t) + else: + column = aliases_dict[base_expr_monad] + expr_ast = ['COLUMN', tableref.alias, column] + expr_monad = ExprMonad.new(translator, t, expr_ast, base_expr_monad.nullable) + assert name not in translator.namespace + translator.namespace[name] = expr_monad def construct_subquery_ast(translator, aliases=None, star=None, distinct=None, is_not_null_checks=False): subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=distinct, is_not_null_checks=is_not_null_checks) - assert attr_offsets is None assert len(subquery_ast) >= 3 and subquery_ast[0] == 'SELECT' select_ast = subquery_ast[1][:] @@ -667,8 +797,13 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ translator.vars = vars.copy() if vars is not None else None translator.vartypes = translator.vartypes.copy() # make HashableDict mutable again translator.vartypes.update(vartypes) - translator.lambda_argnames = list(argnames) - translator.original_names = original_names + + if not original_names: + assert argnames + translator.namespace_stack.append({name: monad for name, monad in izip(argnames, translator.expr_monads)}) + elif argnames: + translator.namespace_stack.append({name: translator.namespace[name] for name in argnames}) + translator.dispatch(func_ast) if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes else: nodes = (func_ast,) @@ -697,7 +832,10 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ - subtranslator = translator_cls(inner_tree, translator) + try: + subtranslator = translator_cls(inner_tree, translator) + except UseAnotherTranslator: + assert False return QuerySetMonad(translator, subtranslator) def postGenExprIf(translator, node): monad = node.test.monad @@ -743,20 +881,9 @@ def postName(translator, 
node): assert monad is not None return monad def resolve_name(translator, name): - t = translator - while t is not None: - stack = t.method_argnames_mapping_stack - if stack and name in stack[-1]: - return stack[-1][name] - argnames = t.lambda_argnames - if argnames is not None and not t.original_names and name in argnames: - i = argnames.index(name) - return t.expr_monads[i] - t = t.parent - tableref = translator.sqlquery.get_tableref(name) - if tableref is not None: - return ObjectIterMonad(translator, tableref, tableref.entity) - return None + if name not in translator.namespace: + throw(TranslationError, 'Name %s is not found in %s' % (name, translator.namespace)) + return translator.namespace[name] def postAdd(translator, node): return node.left.monad + node.right.monad def postSub(translator, node): @@ -828,7 +955,10 @@ def preCallFunc(translator, node): for_expr = ast.GenExprFor(ast.AssName(iter_name, 'OP_ASSIGN'), name_ast, [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(iter_name), [ for_expr ]) translator_cls = translator.__class__ - subtranslator = translator_cls(inner_expr, translator) + try: + subtranslator = translator_cls(inner_expr, translator) + except UseAnotherTranslator: + assert False monad = QuerySetMonad(translator, subtranslator) if method_name == 'exists': monad = monad.nonzero() @@ -935,7 +1065,6 @@ def get_tableref(sqlquery, name_path, from_subquery=False): if sqlquery.parent_sqlquery: return sqlquery.parent_sqlquery.get_tableref(name_path, from_subquery=True) return None - __contains__ = get_tableref def add_tableref(sqlquery, name_path, parent_tableref, attr): tablerefs = sqlquery.tablerefs assert name_path not in tablerefs @@ -980,6 +1109,65 @@ def make_join(tableref, pk_only=False): tableref.joined = True return tableref.alias, entity._pk_columns_ +class ExprTableRef(TableRef): + def __init__(tableref, sqlquery, name, subquery_ast, expr_names, expr_aliases): + TableRef.__init__(tableref, sqlquery, name, None) + 
tableref.subquery_ast = subquery_ast + tableref.expr_names = expr_names + tableref.expr_aliases = expr_aliases + def make_join(tableref, pk_only=False): + assert tableref.subquery_ast[0] == 'SELECT' + if not tableref.joined: + sqlquery = tableref.sqlquery + sqlquery.from_ast.append([tableref.alias, 'SELECT', tableref.subquery_ast[1:]]) + tableref.joined = True + return tableref.alias, None + +class StarTableRef(TableRef): + def __init__(tableref, sqlquery, name, entity, subquery_ast): + TableRef.__init__(tableref, sqlquery, name, entity) + tableref.subquery_ast = subquery_ast + def make_join(tableref, pk_only=False): + entity = tableref.entity + assert tableref.subquery_ast[0] == 'SELECT' + if not tableref.joined: + sqlquery = tableref.sqlquery + sqlquery.from_ast.append([ tableref.alias, 'SELECT', tableref.subquery_ast[1:] ]) + if entity._discriminator_attr_: # ??? + discr_criteria = entity._construct_discriminator_criteria_(tableref.alias) + assert discr_criteria is not None + sqlquery.conditions.append(discr_criteria) + tableref.joined = True + return tableref.alias, entity._pk_columns_ + +class ExprJoinedTableRef(object): + def __init__(tableref, sqlquery, parent_tableref, parent_columns, name, entity): + tableref.sqlquery = sqlquery + tableref.parent_tableref = parent_tableref + tableref.parent_columns = parent_columns + tableref.name = tableref.name_path = name + tableref.entity = entity + tableref.alias = None + tableref.joined = False + tableref.can_affect_distinct = False + tableref.used_attrs = set() + def make_join(tableref, pk_only=False): + entity = tableref.entity + if tableref.joined: + return tableref.alias, tableref.pk_columns + sqlquery = tableref.sqlquery + parent_alias, left_pk_columns = tableref.parent_tableref.make_join() + if pk_only: + tableref.alias = parent_alias + tableref.pk_columns = tableref.parent_columns + return tableref.alias, tableref.pk_columns + tableref.alias = sqlquery.make_alias(tableref.name) + tableref.pk_columns = 
entity._pk_columns_ + join_cond = join_tables(parent_alias, tableref.alias, tableref.parent_columns, tableref.pk_columns) + sqlquery.join_table(parent_alias, tableref.alias, entity._table_, join_cond) + tableref.joined = True + return tableref.alias, tableref.pk_columns + class JoinedTableRef(object): def __init__(tableref, sqlquery, name_path, parent_tableref, attr): tableref.sqlquery = sqlquery @@ -1007,6 +1195,7 @@ def make_join(tableref, pk_only=False): pk_columns = entity._pk_columns_ if not attr.is_collection: if not attr.columns: + # one-to-one relationship with foreign key column on the right side reverse = attr.reverse assert reverse.columns and not reverse.is_collection rentity = reverse.entity @@ -1014,6 +1203,7 @@ def make_join(tableref, pk_only=False): alias = sqlquery.make_alias(tableref.var_name or rentity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, reverse.columns) else: + # one-to-one or many-to-one relationship with foreign key column on the left side if attr.pk_offset is not None: offset = attr.pk_columns_offset left_columns = left_pk_columns[offset:offset+len(attr.columns)] @@ -1027,9 +1217,11 @@ def make_join(tableref, pk_only=False): alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_columns, pk_columns) elif not attr.reverse.is_collection: + # many-to-one relationship alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_pk_columns, attr.reverse.columns) else: + # many-to-many relationship right_m2m_columns = attr.reverse_columns if attr.symmetric else attr.columns if not tableref.joined: m2m_table = attr.table @@ -1327,7 +1519,7 @@ def __call__(monad, *args, **kwargs): root_translator.vartypes.update(func_vartypes) root_translator.vars.update(func_vars) - stack = translator.method_argnames_mapping_stack + stack = translator.namespace_stack stack.append(name_mapping) prev_filter_num = 
translator.filter_num translator.filter_num = func_filter_num @@ -3035,7 +3227,10 @@ def call_group_concat(monad, sep=None, distinct=None): throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__) return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) def getsql(monad): - throw(NotImplementedError) + translator = monad.translator + sub = monad.subtranslator + subquery_ast = sub.construct_subquery_ast() + return subquery_ast def find_or_create_having_ast(sections): groupby_offset = None diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index c91517a44..b0a384a20 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -13,6 +13,7 @@ class Group(db.Entity): class Student(db.Entity): name = Required(unicode) + age = Required(int) group = Required('Group') scholarship = Required(int, default=0) courses = Set('Course') @@ -28,9 +29,9 @@ class Course(db.Entity): with db_session: g1 = Group(id=1) g2 = Group(id=2) - s1 = Student(id=1, name='S1', group=g1, scholarship=0) - s2 = Student(id=2, name='S2', group=g1, scholarship=100) - s3 = Student(id=3, name='S3', group=g2, scholarship=500) + s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) + s2 = Student(id=2, name='S2', age=23, group=g1, scholarship=100) + s3 = Student(id=3, name='S3', age=23, group=g2, scholarship=500) c1 = Course(name='C1', semester=1, students=[s1, s2]) c2 = Course(name='C2', semester=1, students=[s2, s3]) c3 = Course(name='C3', semester=2, students=[s3]) @@ -286,5 +287,63 @@ def test_group_concat_11(self): self.assertEqual(result, '1,2') + @raises_exception(TypeError, 'Query can only iterate over entity or another query (not a list of objects)') + def test_select_from_select_1(self): + query = select(s for s in Student if s.scholarship > 0)[:] + result = set(select(x for x in query)) + self.assertEqual(result, 
{}) + + def test_select_from_select_2(self): + p, q = 50, 400 + query = select(s for s in Student if s.scholarship > p) + result = select(x.id for x in query if x.scholarship < q)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_3(self): + p, q = 50, 400 + g = (s for s in Student if s.scholarship > p) + result = select(x.id for x in g if x.scholarship < q)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_4(self): + p, q = 50, 400 + result = select(x.id for x in (s for s in Student if s.scholarship > p) + if x.scholarship < q)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_5(self): + p, q = 50, 400 + result = select(x.id for x in select(s for s in Student if s.scholarship > 0) + if x.scholarship < 400)[:] + self.assertEqual(set(result), {2}) + + def test_select_from_select_6(self): + query = select(s.name for s in Student if s.scholarship > 0) + result = select(x for x in query if not x.endswith('3')) + self.assertEqual(set(result), {'S2'}) + + @raises_exception(TranslationError, 'Too many values to unpack "for a, b in select(s for ...)" (expected 2, got 1)') + def test_select_from_select_7(self): + query = select(s for s in Student if s.scholarship > 0) + result = select(a for a, b in query) + + @raises_exception(NotImplementedError, 'Please unpack a tuple of (s.name, s.group) in for-loop ' + 'to individual variables (like: "for x, y in ...")') + def test_select_from_select_8(self): + query = select((s.name, s.group) for s in Student if s.scholarship > 0) + result = select(x for x in query) + + @raises_exception(TranslationError, 'Not enough values to unpack "for x, y in ' + 'select(s.name, s.group, s.scholarship for ...)" (expected 2, got 3)') + def test_select_from_select_9(self): + query = select((s.name, s.group, s.scholarship) for s in Student if s.scholarship > 0) + result = select(x for x, y in query) + + def test_select_from_select_10(self): + query = select((s.name, s.age) for s in 
Student if s.scholarship > 0) + result = select(n for n, a in query if n.endswith('2') and a > 20) + self.assertEqual(set(x for x in result), {'S2'}) + + if __name__ == "__main__": unittest.main() diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index be7977454..070e15ea4 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PYPY2 +from pony.py23compat import PYPY2, pickle import unittest from datetime import date @@ -36,14 +36,14 @@ def setUp(self): def tearDown(self): rollback() db_session.__exit__() - @raises_exception(TypeError, 'Cannot iterate over non-entity object') + @raises_exception(TypeError, "Query can only iterate over entity or another query (not a list of objects)") def test1(self): select(s for s in []) - @raises_exception(TypeError, 'Cannot iterate over non-entity object X') + @raises_exception(TypeError, "Cannot iterate over non-entity object X") def test2(self): X = [1, 2, 3] select('x for x in X') - @raises_exception(TypeError, "Cannot iterate over non-entity object") + @raises_exception(TypeError, "Query can only iterate over entity or another query (not a list of objects)") def test3(self): g = Group[1] select(s for s in g.students) @@ -154,6 +154,13 @@ def find_by_gpa(): return lambda s: s.gpa > gpa fn = find_by_gpa() students = list(Student.select(fn)) + def test_pickle(self): + objects = select(s for s in Student if s.scholarship > 0).order_by(desc(Student.id)) + data = pickle.dumps(objects) + rollback() + objects = pickle.loads(data) + self.assertEqual([obj.id for obj in objects], [3, 2]) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py new file mode 100644 index 000000000..cbcf60adf --- /dev/null +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -0,0 +1,225 @@ 
+import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:') + +class Group(db.Entity): + number = PrimaryKey(int) + major = Required(str) + students = Set('Student') + +class Student(db.Entity): + first_name = Required(unicode) + last_name = Required(unicode) + age = Required(int) + group = Required('Group') + scholarship = Required(int, default=0) + courses = Set('Course') + + @property + def full_name(self): + return self.first_name + ' ' + self.last_name + +class Course(db.Entity): + name = Required(unicode) + semester = Required(int) + credits = Required(int) + PrimaryKey(name, semester) + students = Set('Student') + +db.generate_mapping(create_tables=True) + +with db_session: + g1 = Group(number=123, major='Computer Science') + g2 = Group(number=456, major='Graphic Design') + s1 = Student(id=1, first_name='John', last_name='Smith', age=20, group=g1, scholarship=0) + s2 = Student(id=2, first_name='Alex', last_name='Green', age=24, group=g1, scholarship=100) + s3 = Student(id=3, first_name='Mary', last_name='White', age=23, group=g1, scholarship=500) + s4 = Student(id=4, first_name='John', last_name='Brown', age=20, group=g2, scholarship=400) + s5 = Student(id=5, first_name='Bruce', last_name='Lee', age=22, group=g2, scholarship=300) + c1 = Course(name='Math', semester=1, credits=10, students=[s1, s2, s4]) + c2 = Course(name='Computer Science', semester=1, credits=20, students=[s2, s3]) + c3 = Course(name='3D Modeling', semester=2, credits=15, students=[s3, s5]) + + +class TestSelectFromSelect(unittest.TestCase): + @db_session + def test_1(self): # basic select from another query + q = select(s for s in Student if s.scholarship > 0) + q2 = select(s for s in q if s.scholarship < 500) + self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) # single SELECT...FROM expression + + @db_session + def test_2(self): # different variable 
name in the second query + q = select(s for s in Student if s.scholarship > 0) + q2 = select(x for x in q if x.scholarship < 500) + self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_3(self): # selecting single column instead of entity in the second query + q = select(s for s in Student if s.scholarship > 0) + q2 = select(x.first_name for x in q if x.scholarship < 500) + self.assertEqual(set(q2), {'Alex', 'Bruce', 'John'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_4(self): # selecting single column instead of entity in the first query + q = select(s.first_name for s in Student if s.scholarship > 0) + q2 = select(name for name in q if 'r' in name) + self.assertEqual(set(q2), {'Bruce', 'Mary'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_5(self): # selecting hybrid property in the second query + q = select(s for s in Student if s.scholarship > 0) + q2 = select(x.full_name for x in q if x.scholarship < 500) + self.assertEqual(set(q2), {'Alex Green', 'Bruce Lee', 'John Brown'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_6(self): # selecting hybrid property in the first query + q = select(s.full_name for s in Student if s.scholarship < 500) + q2 = select(x for x in q if x.startswith('J')) + self.assertEqual(set(q2), {'John Smith', 'John Brown'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + @raises_exception(ExprEvalError, "s.scholarship > 0 raises NameError: name 's' is not defined") + def test_7(self): # test access to original query var name from the new query + q = select(s.first_name for s in Student if s.scholarship < 500) + q2 = select(x for x in q if s.scholarship > 0) + + @db_session + def test_8(self): # test using external name which is equal to original query var name + class Dummy(object): + scholarship = 1 + s = Dummy() + q = 
select(s.first_name for s in Student if s.scholarship < 500) + q2 = select(x for x in q if s.scholarship > 0) + self.assertEqual(set(q2), {'John', 'Alex', 'Bruce'}) + + @db_session + def test_9(self): # test reusing variable name from the original query + q = select(s for s in Student if s.scholarship > 0) + q2 = select(x for x in q for s in Student if x.scholarship < s.scholarship) + self.assertEqual(set(s.first_name for s in q2), {'Alex', 'John', 'Bruce'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_10(self): # test .filter() + q = select(s for s in Student if s.scholarship > 0) + q2 = q.filter(lambda a: a.scholarship < 500) + q3 = select(x for x in q2 if x.age > 20) + q4 = q3.filter(lambda b: b.age < 24) + self.assertEqual(set(s.first_name for s in q4), {'Bruce'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_11(self): # test .where() + q = select(s for s in Student if s.scholarship > 0) + q2 = q.where(lambda s: s.scholarship < 500) + q3 = select(x for x in q2 if x.age > 20) + q4 = q3.where(lambda x: x.age < 24) # the name should be accessible in previous generator + self.assertEqual(set(s.first_name for s in q4), {'Bruce'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + @raises_exception(TypeError, 'Lambda argument `s` does not correspond to any variable in original query') + def test_12(self): # test .where() + q = select(s for s in Student if s.scholarship > 0) + q2 = q.where(lambda s: s.scholarship < 500) + q3 = select(x for x in q2 if x.age > 20) + q4 = q3.where(lambda s: s.age < 24) + + @db_session + def test_13(self): # select several expressions from the first query + q = select((s.full_name, s.age) for s in Student if s.scholarship > 0) + q2 = select(name for name, age in q if age < 24 and 'e' in name) + self.assertEqual(set(q2), {'Mary White', 'Bruce Lee'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_14(self): # select from entity 
with composite key + q = select(c for c in Course if c.semester == 1) + q2 = select(x.name for x in q if x.name.startswith('M')) + self.assertEqual(set(q2), {'Math'}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_15(self): # SELECT ... FROM (SELECT alias.* FROM ... + q = left_join(s for g in Group for s in g.students if g.number == 123 and s.scholarship > 0) + q2 = select(x.full_name for x in q if x.scholarship > 100) + self.assertEqual(set(q2), {'Mary White'}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) + self.assertTrue('*' in db.last_sql) + + @db_session + def test_16(self): # SELECT ... FROM (grouped-query) + q = select(g for g in Group if count(g.students) > 2) + q2 = select(x.number for x in q) + + self.assertEqual(set(q2), {123}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) + self.assertEqual(db.last_sql.count('GROUP BY'), 1) + self.assertEqual(db.last_sql.count('HAVING'), 1) + self.assertTrue('WHERE' not in db.last_sql) + + @db_session + def test_17(self): # SELECT ... FROM (grouped-query), t1 WHERE ... + q = select(g for g in Group if count(g.students) > 2) + q2 = select(x.major for x in q) + + self.assertEqual(set(q2), {'Computer Science'}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) + self.assertEqual(db.last_sql.count('GROUP BY'), 1) + self.assertEqual(db.last_sql.count('HAVING'), 1) + + @db_session + def test_18(self): # SELECT ... FROM (grouped-query returns composite keys), t1 WHERE ... 
+ q = select((c, count(c.students)) for c in Course if c.semester == 1 and count(c.students) > 1) + q2 = select((x.name, x.credits, y) for x, y in q if x.credits > 10 and y < 3) + + self.assertEqual(set(q2), {('Computer Science', 20, 2)}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + self.assertEqual(db.last_sql.count('LEFT JOIN'), 1) + self.assertEqual(db.last_sql.count('GROUP BY'), 1) + self.assertEqual(db.last_sql.count('HAVING'), 1) + self.assertEqual(db.last_sql.count('WHERE'), 2) + + @db_session + def test_19(self): # multiple for loops in the inner query + q = select((g, s.first_name.lower()) for g in Group for s in g.students) + q2 = select((g.major, n) for g, n in q if g.number == 123 and n[0] == 'j') + self.assertEqual(set(q2), {('Computer Science', 'john')}) + + @db_session + def test_20(self): # additional for loop with inlined subquery + q = select((g, x.first_name.upper()) + for g in Group + for x in select(s for s in Student if s.age < 22) + if x.group == g and g.number == 123 and x.first_name[0] == 'J') + q2 = select(name for g, name in q if g.number == 123) + self.assertEqual(set(q2), {'JOHN'}) + + @db_session + def test_21(self): + objects = select(s for s in Student if s.scholarship > 200)[:] # not query, but query result + q = select(s.first_name for s in Student if s not in objects) + self.assertEqual(set(q), {'John', 'Alex'}) + + @db_session + @raises_exception(TypeError, 'Query can only iterate over entity or another query (not a list of objects)') + def test_22(self): + objects = select(s for s in Student if s.scholarship > 200)[:] # not query, but query result + q = select(s.first_name for s in objects) + + +if __name__ == '__main__': + unittest.main() From b7586458f048b39b32721636c9a7549d1cc735ed Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 21 Jul 2018 15:17:15 +0300 Subject: [PATCH 335/547] New module pony.flask --- pony/flask/__init__.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) create mode 
100644 pony/flask/__init__.py diff --git a/pony/flask/__init__.py b/pony/flask/__init__.py new file mode 100644 index 000000000..d63fa6856 --- /dev/null +++ b/pony/flask/__init__.py @@ -0,0 +1,24 @@ +from pony.orm import db_session +from flask import request + +def _enter_session(): + session = db_session() + request.pony_session = session + session.__enter__() + +def _exit_session(exception): + session = getattr(request, 'pony_session', None) + if session is None: + raise RuntimeError('Request object lost db_session') + session.__exit__(exc=exception) + +class Pony(object): + def __init__(self, app=None): + self.app = None + if app is not None: + self.init_app(app) + + def init_app(self, app): + self.app = app + self.app.before_request(_enter_session) + self.app.teardown_request(_exit_session) \ No newline at end of file From f12cbcbaf1c23b246428c98427de286c957b4544 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 21 Jul 2018 15:19:08 +0300 Subject: [PATCH 336/547] Example for pony.flask added --- pony/flask/example/__init__.py | 0 pony/flask/example/__main__.py | 7 ++++ pony/flask/example/app.py | 16 +++++++ pony/flask/example/config.py | 9 ++++ pony/flask/example/models.py | 10 +++++ pony/flask/example/templates/index.html | 36 ++++++++++++++++ pony/flask/example/templates/login.html | 30 +++++++++++++ pony/flask/example/templates/reg.html | 30 +++++++++++++ pony/flask/example/views.py | 56 +++++++++++++++++++++++++ 9 files changed, 194 insertions(+) create mode 100644 pony/flask/example/__init__.py create mode 100644 pony/flask/example/__main__.py create mode 100644 pony/flask/example/app.py create mode 100644 pony/flask/example/config.py create mode 100644 pony/flask/example/models.py create mode 100644 pony/flask/example/templates/index.html create mode 100644 pony/flask/example/templates/login.html create mode 100644 pony/flask/example/templates/reg.html create mode 100644 pony/flask/example/views.py diff --git a/pony/flask/example/__init__.py 
b/pony/flask/example/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pony/flask/example/__main__.py b/pony/flask/example/__main__.py new file mode 100644 index 000000000..93d11b071 --- /dev/null +++ b/pony/flask/example/__main__.py @@ -0,0 +1,7 @@ +from .views import * +from .app import app + +if __name__ == '__main__': + db.bind(**app.config['PONY']) + db.generate_mapping(create_tables=True) + app.run() \ No newline at end of file diff --git a/pony/flask/example/app.py b/pony/flask/example/app.py new file mode 100644 index 000000000..51d33ad90 --- /dev/null +++ b/pony/flask/example/app.py @@ -0,0 +1,16 @@ +from flask import Flask +from flask_login import LoginManager +from pony.flask import Pony +from .config import config +from .models import db + +app = Flask(__name__) +app.config.update(config) + +Pony(app) +login_manager = LoginManager(app) +login_manager.login_view = 'login' + +@login_manager.user_loader +def load_user(user_id): + return db.User.get(id=user_id) diff --git a/pony/flask/example/config.py b/pony/flask/example/config.py new file mode 100644 index 000000000..0c33fe839 --- /dev/null +++ b/pony/flask/example/config.py @@ -0,0 +1,9 @@ +config = dict( + DEBUG = False, + SECRET_KEY = 'secret_xxx', + PONY = { + 'provider': 'sqlite', + 'filename': 'db.db3', + 'create_db': True + } +) \ No newline at end of file diff --git a/pony/flask/example/models.py b/pony/flask/example/models.py new file mode 100644 index 000000000..f97d56425 --- /dev/null +++ b/pony/flask/example/models.py @@ -0,0 +1,10 @@ +from pony.orm import Database, Required, Optional +from flask_login import UserMixin +from datetime import datetime + +db = Database() + +class User(db.Entity, UserMixin): + login = Required(str, unique=True) + password = Required(str) + last_login = Optional(datetime) \ No newline at end of file diff --git a/pony/flask/example/templates/index.html b/pony/flask/example/templates/index.html new file mode 100644 index 000000000..a40948d31 
--- /dev/null +++ b/pony/flask/example/templates/index.html @@ -0,0 +1,36 @@ + + + Hello! + + + + +
+ {% with messages = get_flashed_messages() %} + {% if messages %} + {% for message in messages %} + + {% endfor %} + {% endif %} + {% endwith %} + {% if not current_user.is_authenticated %} +

Hi, please log in or register


+ {% else %} +

Hi, {{ current_user.login }}. Your last login: {{ current_user.last_login.strftime('%Y-%m-%d') }}

+ Logout +

List of users

+
    + {% for user in users %} +
  • + {% if user == current_user %} + {{ user.login }} + {% else %} + {{ user.login }} + {% endif %} +
  • + {% endfor %} +
+ {% endif %} +
+ + \ No newline at end of file diff --git a/pony/flask/example/templates/login.html b/pony/flask/example/templates/login.html new file mode 100644 index 000000000..562525904 --- /dev/null +++ b/pony/flask/example/templates/login.html @@ -0,0 +1,30 @@ + + + Login page + + + + +
+ {% with messages = get_flashed_messages() %} + {% if messages %} + {% for message in messages %} + + {% endfor %} + {% endif %} + {% endwith %} +

Please login

+
+
+ + + +
+ {% if error %} +

Error: {{ error }} + {% endif %} +

+ + \ No newline at end of file diff --git a/pony/flask/example/templates/reg.html b/pony/flask/example/templates/reg.html new file mode 100644 index 000000000..ae9a27d91 --- /dev/null +++ b/pony/flask/example/templates/reg.html @@ -0,0 +1,30 @@ + + + Login page + + + + +
+ {% with messages = get_flashed_messages() %} + {% if messages %} + {% for message in messages %} + + {% endfor %} + {% endif %} + {% endwith %} +

Register

+
+
+ + + +
+ {% if error %} +

Error: {{ error }} + {% endif %} +

+ + \ No newline at end of file diff --git a/pony/flask/example/views.py b/pony/flask/example/views.py new file mode 100644 index 000000000..a8477a025 --- /dev/null +++ b/pony/flask/example/views.py @@ -0,0 +1,56 @@ +from .app import app +from .models import db +from flask import render_template, request, flash, redirect, abort +from flask_login import current_user, logout_user, login_user, login_required +from datetime import datetime +from pony.orm import flush + +@app.route('/') +def index(): + users = db.User.select() + return render_template('index.html', user=current_user, users=users) + +@app.route('/login', methods=['GET', 'POST']) +def login(): + if request.method == 'POST': + username = request.form['username'] + password = request.form['password'] + possible_user = db.User.get(login=username) + if not possible_user: + flash('Wrong username') + return redirect('/login') + if possible_user.password == password: + possible_user.last_login = datetime.now() + login_user(possible_user) + return redirect('/') + + flash('Wrong password') + return redirect('/login') + else: + return render_template('login.html') + +@app.route('/reg', methods=['GET', 'POST']) +def reg(): + if request.method == 'POST': + username = request.form['username'] + password = request.form['password'] + exist = db.User.get(login=username) + if exist: + flash('Username %s is already taken, choose another one' % username) + return redirect('/reg') + + user = db.User(login=username, password=password) + user.last_login = datetime.now() + flush() + login_user(user) + flash('Successfully registered') + return redirect('/') + else: + return render_template('reg.html') + +@app.route('/logout') +@login_required +def logout(): + logout_user() + flash('Logged out') + return redirect('/') \ No newline at end of file From b89af3df7f214146768def7b30d474b127e123b5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 21 Jul 2018 19:20:38 +0300 Subject: [PATCH 337/547] Update setup.py: use 
setuptools instead of distutils.core --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e1851a511..e60652052 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from __future__ import print_function -from distutils.core import setup +from setuptools import setup import sys name = "pony" From 63af82b20a262822f561ad4229f81456a4359389 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 14 Mar 2018 13:19:23 +0300 Subject: [PATCH 338/547] Include pony.flask and pony.flask.example to setup.py --- MANIFEST.in | 3 ++- setup.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index 1bf3c80a3..8f03fd3f8 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ -include pony/orm/tests/queries.txt \ No newline at end of file +include pony/orm/tests/queries.txt +include pony/flask/example/templates *.html diff --git a/setup.py b/setup.py index e60652052..892beb2a7 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,8 @@ packages = [ "pony", + "pony.flask", + "pony.flask.example", "pony.orm", "pony.orm.dbproviders", "pony.orm.examples", @@ -87,6 +89,11 @@ "pony.utils" ] +package_data = { + 'pony.flask.example': ['templates/*.html'], + 'pony.orm.tests': ['queries.txt'] +} + download_url = "http://pypi.python.org/pypi/pony/" if __name__ == "__main__": @@ -108,5 +115,6 @@ url=url, license=licence, packages=packages, + package_data=package_data, download_url=download_url ) From a67fb001bf9932578af85ee8f7a01ccd08f02c88 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 21 Jul 2018 18:54:50 +0300 Subject: [PATCH 339/547] Update changelog and change Pony version: 0.7.4-dev -> 0.7.4 --- CHANGELOG.md | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 313c2aacf..d4cfc237b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,51 @@ +# Pony ORM 
Release 0.7.4 (2018-07-23) + +## Major features + +* Hybrid methods and properties added: https://docs.ponyorm.com/entities.html#hybrid-methods-and-properties +* Allow to base queries on another queries: `select(x.a for x in prev_query if x.b)` +* Added support of Python 3.7 +* Added support of PyPy +* `group_concat()` aggregate function added +* pony.flask subpackage added for integration with Flask + +## Other features + +* `distinct` option added to aggregate functions +* Support of explicit casting to `float` and `bool` in queries + +## Improvements + +* Apply @cut_traceback decorator only when pony.MODE is 'INTERACTIVE' + +## Bugfixes + +* In SQLite3 `LIKE` is case sensitive now +* #249: Fix incorrect mixin used for Timedelta +* #251: correct dealing with qualified table names +* #301: Fix aggregation over JSON Column +* #306: Support of frozenset constants added +* #308: Fixed an error when assigning JSON attribute value to the same attribute: obj.json_attr = obj.json_attr +* #313: Fix missed retry on exception raised during db_session.__exit__ +* #314: Fix AttributeError: 'NoneType' object has no attribute 'seeds' +* #315: Fix attribute lifting for JSON attributes +* #321: Fix KeyError on obj.delete() +* #325: duplicating percentage sign in raw SQL queries without parameters +* #331: Overriding __len__ in entity fails +* #336: entity declaration serialization +* #357: reconnect after PostgreSQL server closed the connection unexpectedly +* Fix Python implementation of between() function and rename arguments: between(a, x, y) -> between(x, a, b) +* Fix retry handling: in PostgreSQL and Oracle an error can be raised during commit +* Fix optimistic update checks for composite foreign keys +* Don't raise OptimisticCheckError if db_session is not optimistic +* Handling incorrect datetime values in MySQL +* Improved ImportError exception messages when MySQLdb, pymysql, psycopg2 or psycopg2cffi driver was not found +* desc() function fixed to allow reverse its 
effect by calling desc(desc(x)) +* __contains__ method should check if objects belong to the same db_session +* Fix pony.MODE detection; mod_wsgi detection according to official doc +* A lot of inner fixes + + # Pony ORM Release 0.7.3 (2017-10-23) ## New features diff --git a/pony/__init__.py b/pony/__init__.py index 2702f8965..6606c1567 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.4-dev' +__version__ = '0.7.4' uid = str(random.randint(1, 1000000)) From c1c3d15891b74c44d489103511d8fd1cb9bbc796 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 24 Jul 2018 15:08:55 +0300 Subject: [PATCH 340/547] Update pony version: 0.7.4 -> 0.7.5-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 6606c1567..892ceba21 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.4' +__version__ = '0.7.5-dev' uid = str(random.randint(1, 1000000)) From 4603f82faed8620e3b8149761a223f2a7a72c634 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 24 Jul 2018 15:19:51 +0300 Subject: [PATCH 341/547] Fix year in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cc7e5f5c7..05416ff6b 100644 --- a/README.md +++ b/README.md @@ -50,4 +50,4 @@ Meet the PonyORM team, chat with the community members, and get your questions a Join our newsletter at [ponyorm.com](https://ponyorm.com). Reach us on [Twitter](https://twitter.com/ponyorm). -Copyright (c) 2016 Pony ORM, LLC. All rights reserved. team (at) ponyorm.com +Copyright (c) 2018 Pony ORM, LLC. All rights reserved. 
team (at) ponyorm.com From 3d6ee92475eaac6d3faa582023702bacf47e80a0 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 24 Jul 2018 14:54:24 +0300 Subject: [PATCH 342/547] Hybrid method test added --- pony/orm/tests/test_hybrid_methods_and_properties.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py index 7b2b4256b..6190e3fb0 100644 --- a/pony/orm/tests/test_hybrid_methods_and_properties.py +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -130,6 +130,11 @@ def test12(self): @db_session def test13(self): + persons = select(p.full_name for p in Person if count(p.cars_by_color1('white')) > 1) + self.assertEqual(set(persons), {'Alexander Tischenko'}) + + @db_session + def test14(self): # This test checks if accessing function-specific globals works correctly persons = select(p.incorrect_full_name for p in Person if p.has_car)[:] self.assertEqual(set(persons), {'Alexander ***', 'Alexei ***', 'Alexander ***'}) From 021e42dd42861ad493f4e05d8de7ae89d87e9920 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 24 Jul 2018 15:18:53 +0300 Subject: [PATCH 343/547] Fix namespace bug in query.where(...) 
--- pony/orm/sqltranslation.py | 67 +++++++++------- .../tests/test_select_from_select_queries.py | 80 +++++++++++++++++++ 2 files changed, 117 insertions(+), 30 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index b8bcdbbef..e82cec1e3 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -245,9 +245,7 @@ def check_name_is_single(): translator.sqlquery = monad._subselect(translator.sqlquery, extract_outer_conditions=False) tableref = monad.tableref else: assert False # pragma: no cover - new_namespace = translator.namespace.copy() - new_namespace[name] = ObjectIterMonad(translator, tableref, entity) - translator.namespace_stack.append(new_namespace) + translator.namespace[name] = ObjectIterMonad(translator, tableref, entity) elif node.external: iterable = translator.root_translator.vartypes[translator.filter_num, node.src] if isinstance(iterable, SetType): @@ -800,35 +798,44 @@ def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_ if not original_names: assert argnames - translator.namespace_stack.append({name: monad for name, monad in izip(argnames, translator.expr_monads)}) + namespace = {name: monad for name, monad in izip(argnames, translator.expr_monads)} elif argnames: - translator.namespace_stack.append({name: translator.namespace[name] for name in argnames}) - - translator.dispatch(func_ast) - if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes - else: nodes = (func_ast,) - if order_by: - translator.inside_order_by = True - new_order = [] - for node in nodes: - if isinstance(node.monad, SetMixin): - t = node.monad.type.item_type - if isinstance(type(t), type): t = t.__name__ - throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' - % (t, ast2src(node))) - new_order.extend(node.monad.getsql()) - translator.order[:0] = new_order - translator.inside_order_by = False + namespace = {name: translator.namespace[name] for name in argnames} else: - for node in 
nodes: - monad = node.monad - if isinstance(monad, AndMonad): cond_monads = monad.operands - else: cond_monads = [ monad ] - for m in cond_monads: - if not m.aggregated: translator.conditions.extend(m.getsql()) - else: translator.having_conditions.extend(m.getsql()) - translator.vars = None - return translator + namespace = None + if namespace is not None: + translator.namespace_stack.append(namespace) + + try: + translator.dispatch(func_ast) + if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes + else: nodes = (func_ast,) + if order_by: + translator.inside_order_by = True + new_order = [] + for node in nodes: + if isinstance(node.monad, SetMixin): + t = node.monad.type.item_type + if isinstance(type(t), type): t = t.__name__ + throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' + % (t, ast2src(node))) + new_order.extend(node.monad.getsql()) + translator.order[:0] = new_order + translator.inside_order_by = False + else: + for node in nodes: + monad = node.monad + if isinstance(monad, AndMonad): cond_monads = monad.operands + else: cond_monads = [ monad ] + for m in cond_monads: + if not m.aggregated: translator.conditions.extend(m.getsql()) + else: translator.having_conditions.extend(m.getsql()) + translator.vars = None + return translator + finally: + if namespace is not None: + ns = translator.namespace_stack.pop() + assert ns is namespace def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index cbcf60adf..1d63e75c6 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -220,6 +220,86 @@ def test_22(self): objects = select(s for s in Student if s.scholarship > 200)[:] # not query, but query result q = select(s.first_name for s in objects) + @db_session + def test_23(self): + q = select(s for s in Student) + q2 = 
q.filter(lambda x: x.scholarship > 450) + q3 = q2.where(lambda s: s.scholarship < 520) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_24(self): + q = select(s for s in Student) + q2 = q.where(lambda s: s.scholarship > 450) + q3 = q2.filter(lambda x: x.scholarship < 520) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_25(self): + q = Student.select().filter(lambda x: x.scholarship > 450) + q2 = select(s for s in q) + q3 = q2.where(lambda s: s.scholarship < 520) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_26(self): + q = Student.select().filter(lambda x: x.scholarship > 450) + q2 = q.where(lambda s: s.scholarship < 520) + q3 = select(s for s in q2) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_27(self): + q = Student.select().where(lambda s: s.scholarship > 450) + q2 = select(s for s in q) + q3 = q2.filter(lambda x: x.scholarship < 520) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_28(self): + q = Student.select().where(lambda s: s.scholarship > 450) + q2 = q.filter(lambda x: x.scholarship < 520) + q3 = select(s for s in q2) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_29(self): + q = select(s for s in Student) + q2 = q.where(lambda s: s.scholarship > 450) + q3 = q2.where(lambda s: s.scholarship < 520) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_30(self): + q = select(s for s in Student) + q2 = q.filter(lambda x: x.scholarship > 450) + q3 = q2.filter(lambda z: z.scholarship < 520) + self.assertEqual(set(q3), {Student[3]}) + + @db_session + def test_31(self): + q = select(s for s in Student).order_by(lambda s: s.scholarship) + q2 = q.where(lambda s: s.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + + @db_session + def test_32(self): + q = select(s for s in Student).order_by(lambda s: s.scholarship) + q2 = q.filter(lambda z: z.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + + 
@db_session + def test_33(self): + q = select(s for s in Student).sort_by(lambda x: x.scholarship) + q2 = q.where(lambda s: s.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + + @db_session + def test_34(self): + q = select(s for s in Student).sort_by(lambda x: x.scholarship) + q2 = q.filter(lambda s: s.scholarship > 450) + self.assertEqual(set(q2), {Student[3]}) + if __name__ == '__main__': unittest.main() From 607b8d9df32a873690a31c5667cb45ac6bfd19b3 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 24 Jul 2018 15:35:47 +0300 Subject: [PATCH 344/547] Update changelog and pony version: 0.7.5-dev -> 0.7.5 --- CHANGELOG.md | 7 +++++++ pony/__init__.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d4cfc237b..af46c62b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,10 @@ +# Pony ORM Release 0.7.5 (2018-07-24) + +## Bugfixes + +* `query.where` and `query.filter` method bug introduced in 0.7.4 was fixed + + # Pony ORM Release 0.7.4 (2018-07-23) ## Major features diff --git a/pony/__init__.py b/pony/__init__.py index 892ceba21..2cbb6adf2 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.5-dev' +__version__ = '0.7.5' uid = str(random.randint(1, 1000000)) From d5b9376d62dcef87b49a9404a3285145ace0f7e7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 31 Jul 2018 10:16:22 +0300 Subject: [PATCH 345/547] Update Pony version: 0.7.5 -> 0.7.6-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 2cbb6adf2..90909ed2a 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.5' +__version__ = '0.7.6-dev' uid = str(random.randint(1, 1000000)) From 8bdfc531eaafdab43499063b45b8944d65aa2c8e Mon Sep 17 00:00:00 2001 From: 
Alexander Kozlovsky Date: Fri, 27 Jul 2018 16:58:12 +0300 Subject: [PATCH 346/547] Fix bulk delete bug introduced in 0.7.4 --- pony/orm/sqltranslation.py | 6 +++++- .../orm/tests/test_declarative_sqltranslator.py | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e82cec1e3..40ae94f71 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -890,7 +890,11 @@ def postName(translator, node): def resolve_name(translator, name): if name not in translator.namespace: throw(TranslationError, 'Name %s is not found in %s' % (name, translator.namespace)) - return translator.namespace[name] + monad = translator.namespace[name] + assert isinstance(monad, Monad) + if monad.translator is not translator: + monad.translator.sqlquery.used_from_subquery = True + return monad def postAdd(translator, node): return node.left.monad + node.right.monad def postSub(translator, node): diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 396732a63..65549c6ed 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -372,7 +372,22 @@ def test_optimized_1(self): def test_optimized_2(self): q = select((s, count(s.courses)) for s in Student if count(s.courses) > 1) self.assertEqual(set(q), {(Student[2], 2)}) - + def test_delete(self): + q = select(g for g in Grade if g.teacher.id == 101).delete() + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) + def test_delete_2(self): + delete(g for g in Grade if g.teacher.id == 101) + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) + def test_delete_3(self): + select(g for g in Grade if g.teacher.id == 101).delete(bulk=True) + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) + def test_delete_4(self): + select(g for g 
in Grade if exists(g2 for g2 in Grade if g2.value > g.value)).delete(bulk=True) + q2 = select(g for g in Grade)[:] + self.assertEqual([g.value for g in q2], ['C']) if __name__ == "__main__": unittest.main() From 1e07706055a16b049e215389a37f47c19485121c Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 27 Jul 2018 14:20:55 +0300 Subject: [PATCH 347/547] Close #344: Limit without offset --- pony/orm/core.py | 13 +++++--- pony/orm/dbproviders/oracle.py | 14 ++++---- pony/orm/sqltranslation.py | 8 ++++- pony/orm/tests/queries.txt | 30 +++++++++++++++++ .../tests/test_declarative_orderby_limit.py | 33 +++++++++++++++++-- 5 files changed, 83 insertions(+), 15 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index bd6dad0bd..e3b4c2d57 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3329,7 +3329,7 @@ def select(wrapper, *args): query = query.filter(func, globals, locals) return query filter = select - def limit(wrapper, limit, offset=None): + def limit(wrapper, limit=None, offset=None): return wrapper.select().limit(limit, offset) def page(wrapper, pagenum, pagesize=10): return wrapper.select().page(pagenum, pagesize) @@ -5809,9 +5809,12 @@ def __getitem__(query, key): elif start < 0: throw(TypeError, "Parameter 'start' of slice object cannot be negative") stop = key.stop if stop is None: - if not start: return query._fetch() - else: throw(TypeError, "Parameter 'stop' of slice object should be specified") - if start >= stop: return [] + if not start: + return query._fetch() + else: + return query._fetch(limit=None, offset=start) + if start >= stop: + return query._fetch(limit=0) return query._fetch(limit=stop-start, offset=start) def _fetch(query, limit=None, offset=None, lazy=False): return QueryResult(query, limit, offset, lazy=lazy) @@ -5819,7 +5822,7 @@ def _fetch(query, limit=None, offset=None, lazy=False): def fetch(query, limit=None, offset=None): return query._fetch(limit, offset) @cut_traceback - def limit(query, limit, 
offset=None): + def limit(query, limit=None, offset=None): return query._fetch(limit, offset, lazy=True) @cut_traceback def page(query, pagenum, pagesize=10): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index bb6eaddad..4da915a20 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -175,7 +175,8 @@ def SELECT(builder, *sections): indent0 = '' x = 't.*' - if not limit: pass + if not limit and not offset: + pass elif not offset: result = [ indent0, 'SELECT * FROM (\n' ] builder.indent += 1 @@ -188,13 +189,14 @@ def SELECT(builder, *sections): builder.indent += 2 result.extend(builder._subquery(*sections)) builder.indent -= 2 - result.extend((indent2, ') t ')) - if limit[0] == 'VALUE' and offset[0] == 'VALUE' \ - and isinstance(limit[1], int) and isinstance(offset[1], int): + if limit[1] is None: + result.extend((indent2, ') t\n')) + result.extend((indent, ') t WHERE "row-num" > ', builder(offset), '\n')) + else: + result.extend((indent2, ') t ')) total_limit = [ 'VALUE', limit[1] + offset[1] ] result.extend(('WHERE ROWNUM <= ', builder(total_limit), '\n')) - else: result.extend(('WHERE ROWNUM <= ', builder(limit), ' + ', builder(offset), '\n')) - result.extend((indent, ') t WHERE "row-num" > ', builder(offset), '\n')) + result.extend((indent, ') t WHERE "row-num" > ', builder(offset), '\n')) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 40ae94f71..adb5cacbb 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -670,8 +670,14 @@ def ast_transformer(ast): if translator.order and not aggr_func_name: sql_ast.append([ 'ORDER_BY' ] + translator.order) - if limit is not None: + if limit is not None or offset is not None: assert not aggr_func_name + provider = translator.database.provider + if limit is None: + if provider.dialect == 'SQLite': + limit = -1 + elif 
provider.dialect == 'MySQL': + limit = 18446744073709551615 limit_section = [ 'LIMIT', [ 'VALUE', limit ]] if offset: limit_section.append([ 'VALUE', offset ]) sql_ast = sql_ast + [ limit_section ] diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 6d16da848..29c4a1e17 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -944,3 +944,33 @@ SELECT "g"."NUMBER", LISTAGG("s"."NAME", '+') WITHIN GROUP(ORDER BY 1) FROM "GROUP" "g", "STUDENT" "s" WHERE "g"."NUMBER" = "s"."GROUP" GROUP BY "g"."NUMBER" + +# Test offset without limit + +>>> select(s for s in Student)[3:] + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +LIMIT -1 OFFSET 3 + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +LIMIT null OFFSET 3 + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +LIMIT 18446744073709551615 OFFSET 3 + +Oracle: + +SELECT t.* FROM ( + SELECT t.*, ROWNUM "row-num" FROM ( + SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" + FROM "STUDENT" "s" + ) t +) t WHERE "row-num" > 3 + diff --git a/pony/orm/tests/test_declarative_orderby_limit.py b/pony/orm/tests/test_declarative_orderby_limit.py index 39aeca96d..b7467dfba 100644 --- a/pony/orm/tests/test_declarative_orderby_limit.py +++ b/pony/orm/tests/test_declarative_orderby_limit.py @@ -74,9 +74,9 @@ def test10(self): students = set(select(s for s in Student).order_by(Student.id)[:4]) self.assertEqual(students, {Student[1], Student[2], Student[3], Student[4]}) - @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified") - def test11(self): - students = select(s for s in Student).order_by(Student.id)[4:] + # @raises_exception(TypeError, "Parameter 'stop' of slice object should be specified") + # def test11(self): + # students = select(s for s in Student).order_by(Student.id)[4:] 
@raises_exception(TypeError, "Parameter 'start' of slice object cannot be negative") def test12(self): @@ -116,5 +116,32 @@ def test19(self): students = q[:] self.assertEqual(students, [Student[1], Student[2], Student[3], Student[4], Student[5]]) + def test20(self): + q = select(s for s in Student).limit(offset=2) + self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) + self.assertTrue('LIMIT -1 OFFSET 2' in db.last_sql) + + def test21(self): + q = select(s for s in Student).limit(0, offset=2) + self.assertEqual(set(q), set()) + + def test22(self): + q = select(s for s in Student).order_by(Student.id).limit(offset=1) + self.assertEqual(set(q), {Student[2], Student[3], Student[4], Student[5]}) + + def test23(self): + q = select(s for s in Student)[2:2] + self.assertEqual(set(q), set()) + self.assertTrue('LIMIT 0' in db.last_sql) + + def test24(self): + q = select(s for s in Student)[2:] + self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) + + def test25(self): + q = select(s for s in Student)[:2] + self.assertEqual(set(q), {Student[2], Student[1]}) + + if __name__ == "__main__": unittest.main() From 79070307dfbe1bf052356a13e4848ac8d316e7a0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 30 Jul 2018 17:43:07 +0300 Subject: [PATCH 348/547] Refactoring: simplify SQLBuilder.LIMIT() method --- pony/orm/core.py | 4 ++-- pony/orm/dbproviders/oracle.py | 14 ++++++-------- pony/orm/sqlbuilding.py | 13 ++++++++++--- pony/orm/sqltranslation.py | 6 +++--- 4 files changed, 21 insertions(+), 16 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e3b4c2d57..cc6b55fbe 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3098,7 +3098,7 @@ def is_empty(wrapper): select_list = [ 'ALL' ] + [ [ 'COLUMN', None, column ] for column in attr.columns ] attr_offsets = None sql_ast = [ 'SELECT', select_list, [ 'FROM', [ None, 'TABLE', table_name ] ], - where_list, [ 'LIMIT', [ 'VALUE', 1 ] ] ] + where_list, [ 'LIMIT', 1 ] ] sql, adapter = 
database._ast2sql(sql_ast) attr.cached_empty_sql = sql, adapter, attr_offsets else: sql, adapter, attr_offsets = cached_sql @@ -3993,7 +3993,7 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda if not for_update: sql_ast = [ 'SELECT', select_list, from_list, where_list ] else: sql_ast = [ 'SELECT_FOR_UPDATE', bool(nowait), select_list, from_list, where_list ] if order_by_pk: sql_ast.append([ 'ORDER_BY' ] + [ [ 'COLUMN', None, column ] for column in entity._pk_columns_ ]) - if limit is not None: sql_ast.append([ 'LIMIT', [ 'VALUE', limit ] ]) + if limit is not None: sql_ast.append([ 'LIMIT', limit ]) database = entity._database_ sql, adapter = database._ast2sql(sql_ast) cached_sql = sql, adapter, attr_offsets diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 4da915a20..ed07806d3 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -182,21 +182,19 @@ def SELECT(builder, *sections): builder.indent += 1 result.extend(builder._subquery(*sections)) builder.indent -= 1 - result.extend((indent, ') WHERE ROWNUM <= ', builder(limit), '\n')) + result.extend((indent, ') WHERE ROWNUM <= %d\n' % limit)) else: indent2 = indent + builder.indent_spaces result = [ indent0, 'SELECT %s FROM (\n' % x, indent2, 'SELECT t.*, ROWNUM "row-num" FROM (\n' ] builder.indent += 2 result.extend(builder._subquery(*sections)) builder.indent -= 2 - if limit[1] is None: - result.extend((indent2, ') t\n')) - result.extend((indent, ') t WHERE "row-num" > ', builder(offset), '\n')) + if limit is None: + result.append('%s) t\n' % indent2) + result.append('%s) t WHERE "row-num" > %d\n' % (indent, offset)) else: - result.extend((indent2, ') t ')) - total_limit = [ 'VALUE', limit[1] + offset[1] ] - result.extend(('WHERE ROWNUM <= ', builder(total_limit), '\n')) - result.extend((indent, ') t WHERE "row-num" > ', builder(offset), '\n')) + result.append('%s) t WHERE ROWNUM <= %d\n' % (indent2, limit + offset)) + 
result.append('%s) t WHERE "row-num" > %d\n' % (indent, offset)) if builder.indent: indent = builder.indent_spaces * builder.indent return '(\n', result, indent + ')' diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 043e33f7a..b38034388 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, izip, imap, itervalues, basestring, unicode, buffer +from pony.py23compat import PY2, izip, imap, itervalues, basestring, unicode, buffer, int_types from operator import attrgetter from decimal import Decimal @@ -359,8 +359,15 @@ def DESC(builder, expr): return builder(expr), ' DESC' @indentable def LIMIT(builder, limit, offset=None): - if not offset: return 'LIMIT ', builder(limit), '\n' - else: return 'LIMIT ', builder(limit), ' OFFSET ', builder(offset), '\n' + if limit is None: + limit = 'null' + else: + assert isinstance(limit, int_types) + assert offset is None or isinstance(offset, int) + if offset: + return 'LIMIT %s OFFSET %d\n' % (limit, offset) + else: + return 'LIMIT %s\n' % limit def COLUMN(builder, table_alias, col_name): if builder.suppress_aliases or not table_alias: return [ '%s' % builder.quote_name(col_name) ] diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index adb5cacbb..dbc873530 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -678,9 +678,9 @@ def ast_transformer(ast): limit = -1 elif provider.dialect == 'MySQL': limit = 18446744073709551615 - limit_section = [ 'LIMIT', [ 'VALUE', limit ]] - if offset: limit_section.append([ 'VALUE', offset ]) - sql_ast = sql_ast + [ limit_section ] + limit_section = [ 'LIMIT', limit ] + if offset: limit_section.append(offset) + sql_ast.append(limit_section) sql_ast = ast_transformer(sql_ast) return sql_ast, attr_offsets From b1cb57cd7dfbd9764441e8afa2a87796e61dabb4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 30 
Jul 2018 17:19:29 +0300 Subject: [PATCH 349/547] Fixes #369: add QueryResult.__add__(), .__radd__(), .shuffle() and .to_list(); add error stubs for mutable list methods --- pony/orm/core.py | 33 ++++++++++ .../tests/test_declarative_sqltranslator.py | 64 +++++++++++++++++++ 2 files changed, 97 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index cc6b55fbe..7b84e38b2 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5919,6 +5919,11 @@ def __length_hint__(self): return len(self._query_result) - self._position +def make_query_result_method_error_stub(name, title=None): + def func(self, *args, **kwargs): + throw(TypeError, 'In order to do %s, cast QueryResult to list first' % (title or name)) + return func + class QueryResult(object): __slots__ = '_query', '_limit', '_offset', '_items', '_expr_type', '_col_names' def __init__(self, query, limit, offset, lazy): @@ -5985,6 +5990,8 @@ def reverse(self): self._get_items().reverse() def sort(self, *args, **kwargs): self._get_items().sort(*args, **kwargs) + def shuffle(self): + shuffle(self._get_items()) @cut_traceback def show(self, width=None): if self._items is None: @@ -6041,6 +6048,32 @@ def to_str(x): def to_json(self, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return self._query._database.to_json(self, include, exclude, converter, with_schema, schema_hash) + def __add__(self, other): + result = [] + result.extend(self) + result.extend(other) + return result + def __radd__(self, other): + result = [] + result.extend(other) + result.extend(self) + return result + def to_list(self): + return list(self) + + __setitem__ = make_query_result_method_error_stub('__setitem__', 'item assignment') + __delitem__ = make_query_result_method_error_stub('__delitem__', 'item deletion') + __iadd__ = make_query_result_method_error_stub('__iadd__', '+=') + __imul__ = make_query_result_method_error_stub('__imul__', '*=') + __mul__ = make_query_result_method_error_stub('__mul__', '*') 
+ __rmul__ = make_query_result_method_error_stub('__rmul__', '*') + append = make_query_result_method_error_stub('append', 'append') + clear = make_query_result_method_error_stub('clear', 'clear') + extend = make_query_result_method_error_stub('extend', 'extend') + insert = make_query_result_method_error_stub('insert', 'insert') + pop = make_query_result_method_error_stub('pop', 'pop') + remove = make_query_result_method_error_stub('remove', 'remove') + @cut_traceback def show(entity): diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 65549c6ed..8181646c7 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -388,6 +388,70 @@ def test_delete_4(self): select(g for g in Grade if exists(g2 for g2 in Grade if g2.value > g.value)).delete(bulk=True) q2 = select(g for g in Grade)[:] self.assertEqual([g.value for g in q2], ['C']) + def test_select_2(self): + result = select(s for s in Student)[:] + self.assertEqual(result, [Student[1], Student[2], Student[3]]) + def test_select_add(self): + result = [None] + select(s for s in Student)[:] + self.assertEqual(result, [None, Student[1], Student[2], Student[3]]) + def test_query_result_radd(self): + result = select(s for s in Student)[:] + [None] + self.assertEqual(result, [Student[1], Student[2], Student[3], None]) + def test_query_result_sort(self): + result = select(s for s in Student)[:] + result.sort() + self.assertEqual(result, [Student[1], Student[2], Student[3]]) + def test_query_result_reverse(self): + result = select(s for s in Student)[:] + items = list(result) + result.reverse() + self.assertEqual(items, list(reversed(result))) + def test_query_result_shuffle(self): + result = select(s for s in Student)[:] + items = set(result) + result.shuffle() + self.assertEqual(items, set(result)) + def test_query_result_to_list(self): + result = select(s for s in Student)[:] + items = 
result.to_list() + self.assertTrue(type(items) is list) + @raises_exception(TypeError, 'In order to do item assignment, cast QueryResult to list first') + def test_query_result_setitem(self): + result = select(s for s in Student)[:] + result[0] = None + @raises_exception(TypeError, 'In order to do item deletion, cast QueryResult to list first') + def test_query_result_delitem(self): + result = select(s for s in Student)[:] + del result[0] + @raises_exception(TypeError, 'In order to do +=, cast QueryResult to list first') + def test_query_result_iadd(self): + result = select(s for s in Student)[:] + result += None + @raises_exception(TypeError, 'In order to do append, cast QueryResult to list first') + def test_query_result_append(self): + result = select(s for s in Student)[:] + result.append(None) + @raises_exception(TypeError, 'In order to do clear, cast QueryResult to list first') + def test_query_result_clear(self): + result = select(s for s in Student)[:] + result.clear() + @raises_exception(TypeError, 'In order to do extend, cast QueryResult to list first') + def test_query_result_extend(self): + result = select(s for s in Student)[:] + result.extend([]) + @raises_exception(TypeError, 'In order to do insert, cast QueryResult to list first') + def test_query_result_insert(self): + result = select(s for s in Student)[:] + result.insert(0, None) + @raises_exception(TypeError, 'In order to do pop, cast QueryResult to list first') + def test_query_result_pop(self): + result = select(s for s in Student)[:] + result.pop() + @raises_exception(TypeError, 'In order to do remove, cast QueryResult to list first') + def test_query_result_remove(self): + result = select(s for s in Student)[:] + result.remove(None) + if __name__ == "__main__": unittest.main() From 68e0f1f5bf008224a2f7204c1b20f272316fa4ec Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 31 Jul 2018 10:12:23 +0300 Subject: [PATCH 350/547] Interactive mode support for PyCharm console --- 
pony/__init__.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/pony/__init__.py b/pony/__init__.py index 90909ed2a..388a44fc2 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -20,9 +20,12 @@ def detect_mode(): except: pass else: return 'MOD_WSGI' - try: - sys.modules['__main__'].__file__ - except AttributeError: + main = sys.modules['__main__'] + + if not hasattr(main, '__file__'): # console + return 'INTERACTIVE' + + if getattr(main, 'INTERACTIVE_MODE_AVAILABLE', False): # pycharm console return 'INTERACTIVE' if 'flup.server.fcgi' in sys.modules: return 'FCGI-FLUP' From cfad342601224e8fefdd9bc7fd996892396c922b Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 31 Jul 2018 12:04:40 +0300 Subject: [PATCH 351/547] Fixes #355: Python2 buffer PrimaryKey returns as read-write buffer --- pony/orm/dbapiprovider.py | 3 ++- pony/orm/tests/test_bug_355.py | 23 +++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) create mode 100644 pony/orm/tests/test_bug_355.py diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index cd913230f..ef7ad8bd8 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -641,7 +641,8 @@ def validate(converter, val, obj=None): if isinstance(val, str): return buffer(val) throw(TypeError, "Attribute %r: expected type is 'buffer'. 
Got: %r" % (converter.attr, type(val))) def sql2py(converter, val): - if not isinstance(val, buffer): + if not isinstance(val, buffer) or \ + (PY2 and converter.attr.pk_offset is not None and 'read-write' in repr(val)): # Issue 355 try: val = buffer(val) except: pass return val diff --git a/pony/orm/tests/test_bug_355.py b/pony/orm/tests/test_bug_355.py new file mode 100644 index 000000000..9a579ebab --- /dev/null +++ b/pony/orm/tests/test_bug_355.py @@ -0,0 +1,23 @@ +import unittest + +from pony import orm +from pony.py23compat import buffer + +class Test(unittest.TestCase): + def test_1(self): + db = orm.Database('sqlite', ':memory:') + + class Buf(db.Entity): + pk = orm.PrimaryKey(buffer) + + db.generate_mapping(create_tables=True) + + x = buffer(b'123') + + with orm.db_session: + Buf(pk=x) + orm.commit() + + with orm.db_session: + t = Buf[x] + From 8d9cbec0d4a916a1cad3edecd5c6f2b508e8daed Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 31 Jul 2018 13:59:51 +0300 Subject: [PATCH 352/547] Fixes #370 Memory leak in 0.7.4-0.7.5, caused by 88cb468c --- pony/orm/core.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7b84e38b2..79e33afb4 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5303,24 +5303,23 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): def unpickle_query(query_result): return query_result -filter_num_counter = itertools.count() - class Query(object): def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False): assert isinstance(tree, ast.GenExprInner) tree, extractors = create_extractors(code_key, tree, globals, locals, special_functions, const_functions) - filter_num = next(filter_num_counter) + filter_num = 0 vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) node = tree.quals[0].iter origin = vars[filter_num, node.src] if isinstance(origin, Query): - database = 
origin._translator.database + base_query = origin elif isinstance(origin, QueryResult): - database = origin._query._translator.database + base_query = origin._query elif isinstance(origin, QueryResultIterator): - database = origin._query_result._query._translator.database + base_query = origin._query_result._query else: + base_query = None if isinstance(origin, EntityIter): origin = origin.entity elif not isinstance(origin, EntityMeta): @@ -5331,6 +5330,12 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) + if base_query is not None: + database = base_query._translator.database + filter_num = base_query._filter_num + 1 + vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) + + query._filter_num = filter_num database.provider.normalize_vars(vars, vartypes) query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) @@ -5694,23 +5699,23 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names else: original_names = True - filter_num = next(filter_num_counter) + new_filter_num = query._filter_num + 1 func_ast, extractors = create_extractors( func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.namespace) if extractors: - vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) + vars, vartypes = extract_vars(new_filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) new_vars = query._vars.copy() new_vars.update(vars) else: new_vars, vartypes = query._vars, HashableDict() tup = (('order_by' if order_by else 'where' if original_names else 'filter', func_id, vartypes),) new_key = HashableDict(query._key, 
filters=query._key['filters'] + tup) - new_filters = query._filters + (('apply_lambda', filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) + new_filters = query._filters + (('apply_lambda', new_filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: prev_optimized = prev_translator.optimize - new_translator = prev_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) + new_translator = prev_translator.apply_lambda(new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: @@ -5724,9 +5729,10 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names except UseAnotherTranslator: assert False new_translator = query._reapply_filters(new_translator) - new_translator = new_translator.apply_lambda(filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) + new_translator = new_translator.apply_lambda(new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) query._database._translator_cache[new_key] = new_translator - return query._clone(_vars=new_vars, _key=new_key, _filters=new_filters, _translator=new_translator) + return query._clone(_filter_num=new_filter_num, _vars=new_vars, _key=new_key, _filters=new_filters, + _translator=new_translator) def _reapply_filters(query, translator): for tup in query._filters: method_name, args = tup[0], tup[1:] From f5c76c47c52f21a3a8f7d0c5287c14cbb5ad7c44 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 2 Aug 2018 04:41:21 +0300 Subject: [PATCH 353/547] Internal rename: base_query -> prev_query, base_translator -> prev_translator --- pony/orm/core.py | 14 +++++++------- 
pony/orm/sqltranslation.py | 10 +++++----- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 79e33afb4..e429f9474 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5313,13 +5313,13 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False node = tree.quals[0].iter origin = vars[filter_num, node.src] if isinstance(origin, Query): - base_query = origin + prev_query = origin elif isinstance(origin, QueryResult): - base_query = origin._query + prev_query = origin._query elif isinstance(origin, QueryResultIterator): - base_query = origin._query_result._query + prev_query = origin._query_result._query else: - base_query = None + prev_query = None if isinstance(origin, EntityIter): origin = origin.entity elif not isinstance(origin, EntityMeta): @@ -5330,9 +5330,9 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if database is None: throw(TranslationError, 'Entity %s is not mapped to a database' % origin.__name__) if database.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % origin.__name__) - if base_query is not None: - database = base_query._translator.database - filter_num = base_query._filter_num + 1 + if prev_query is not None: + database = prev_query._translator.database + filter_num = prev_query._filter_num + 1 vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) query._filter_num = filter_num diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index dbc873530..892dc6a58 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -263,10 +263,10 @@ def check_name_is_single(): tableref.make_join() translator.namespace[name] = node.monad = ObjectIterMonad(translator, tableref, entity) elif isinstance(iterable, QueryType): - base_translator = deepcopy(iterable.translator) - database = base_translator.database + prev_translator = deepcopy(iterable.translator) 
+ database = prev_translator.database try: - translator.process_query_qual(base_translator, names, try_extend_base_query=not i) + translator.process_query_qual(prev_translator, names, try_extend_prev_query=not i) except UseAnotherTranslator as e: translator = e.translator else: throw(TranslationError, 'Inside declarative query, iterator must be entity. ' @@ -444,7 +444,7 @@ def can_be_optimized(translator): if not aggr_path.startswith(name): return False return aggr_path - def process_query_qual(translator, other_translator, names, try_extend_base_query=False): + def process_query_qual(translator, other_translator, names, try_extend_prev_query=False): sqlquery = translator.sqlquery tablerefs = sqlquery.tablerefs expr_types = other_translator.expr_type @@ -468,7 +468,7 @@ def process_query_qual(translator, other_translator, names, try_extend_base_quer ', '.join(ast2src(m.node) for m in other_translator.expr_monads), len(names), expr_count)) - if try_extend_base_query: + if try_extend_prev_query: if other_translator.aggregated: pass elif other_translator.left_join: pass else: From bb0fa0c39bb746a7fac4945542ad0e069cacb76d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 2 Aug 2018 17:59:07 +0300 Subject: [PATCH 354/547] Internal rename: other_translator -> prev_translator --- pony/orm/sqltranslation.py | 50 +++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 892dc6a58..7c794ad1a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -444,54 +444,54 @@ def can_be_optimized(translator): if not aggr_path.startswith(name): return False return aggr_path - def process_query_qual(translator, other_translator, names, try_extend_prev_query=False): + def process_query_qual(translator, prev_translator, names, try_extend_prev_query=False): sqlquery = translator.sqlquery tablerefs = sqlquery.tablerefs - expr_types = 
other_translator.expr_type + expr_types = prev_translator.expr_type if not isinstance(expr_types, tuple): expr_types = (expr_types,) expr_count = len(expr_types) if expr_count > 1 and len(names) == 1: throw(NotImplementedError, 'Please unpack a tuple of (%s) in for-loop to individual variables (like: "for x, y in ...")' - % (', '.join(ast2src(m.node) for m in other_translator.expr_monads))) + % (', '.join(ast2src(m.node) for m in prev_translator.expr_monads))) elif expr_count > len(names): throw(TranslationError, 'Not enough values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' % (', '.join(names), - ', '.join(ast2src(m.node) for m in other_translator.expr_monads), + ', '.join(ast2src(m.node) for m in prev_translator.expr_monads), len(names), expr_count)) elif expr_count < len(names): throw(TranslationError, 'Too many values to unpack "for %s in select(%s for ...)" (expected %d, got %d)' % (', '.join(names), - ', '.join(ast2src(m.node) for m in other_translator.expr_monads), + ', '.join(ast2src(m.node) for m in prev_translator.expr_monads), len(names), expr_count)) if try_extend_prev_query: - if other_translator.aggregated: pass - elif other_translator.left_join: pass + if prev_translator.aggregated: pass + elif prev_translator.left_join: pass else: assert translator.parent is None - assert other_translator.vars is None - other_translator.filter_num = translator.filter_num - other_translator.extractors.update(translator.extractors) - other_translator.vars = translator.vars - other_translator.vartypes.update(translator.vartypes) - other_translator.left_join = translator.left_join - other_translator.optimize = translator.optimize - other_translator.namespace_stack = [ - {name: expr for name, expr in izip(names, other_translator.expr_monads)} + assert prev_translator.vars is None + prev_translator.filter_num = translator.filter_num + prev_translator.extractors.update(translator.extractors) + prev_translator.vars = translator.vars + 
prev_translator.vartypes.update(translator.vartypes) + prev_translator.left_join = translator.left_join + prev_translator.optimize = translator.optimize + prev_translator.namespace_stack = [ + {name: expr for name, expr in izip(names, prev_translator.expr_monads)} ] - raise UseAnotherTranslator(other_translator) + raise UseAnotherTranslator(prev_translator) - if len(names) == 1 and isinstance(other_translator.expr_type, EntityMeta) \ - and not other_translator.aggregated and not other_translator.distinct: + if len(names) == 1 and isinstance(prev_translator.expr_type, EntityMeta) \ + and not prev_translator.aggregated and not prev_translator.distinct: name = names[0] - entity = other_translator.expr_type - [expr_monad] = other_translator.expr_monads + entity = prev_translator.expr_type + [expr_monad] = prev_translator.expr_monads entity_alias = expr_monad.tableref.alias - subquery_ast = other_translator.construct_subquery_ast(star=entity_alias) + subquery_ast = prev_translator.construct_subquery_ast(star=entity_alias) tableref = StarTableRef(sqlquery, name, entity, subquery_ast) tablerefs[name] = tableref tableref.make_join() @@ -499,7 +499,7 @@ def process_query_qual(translator, other_translator, names, try_extend_prev_quer else: aliases = [] aliases_dict = {} - for name, base_expr_monad in izip(names, other_translator.expr_monads): + for name, base_expr_monad in izip(names, prev_translator.expr_monads): t = base_expr_monad.type if isinstance(t, EntityMeta): t_aliases = [] @@ -512,13 +512,13 @@ def process_query_qual(translator, other_translator, names, try_extend_prev_quer aliases.append(name) aliases_dict[base_expr_monad] = name - subquery_ast = other_translator.construct_subquery_ast(aliases=aliases) + subquery_ast = prev_translator.construct_subquery_ast(aliases=aliases) tableref = ExprTableRef(sqlquery, 't', subquery_ast, names, aliases) for name in names: tablerefs[name] = tableref tableref.make_join() - for name, base_expr_monad in izip(names, 
other_translator.expr_monads): + for name, base_expr_monad in izip(names, prev_translator.expr_monads): t = base_expr_monad.type if isinstance(t, EntityMeta): columns = aliases_dict[base_expr_monad] From f185b053e8d267840d44d65de93f9070636f615b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 2 Aug 2018 04:56:09 +0300 Subject: [PATCH 355/547] Minor refactoring --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e429f9474..d43dbd7b4 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5456,8 +5456,8 @@ def _actual_fetch(query, limit=None, offset=None): cache = database._get_cache() if query._for_update: cache.immediate = True cache.prepare_connection_for_query_execution() # may clear cache.query_results - try: items = cache.query_results[query_key] - except KeyError: + items = cache.query_results.get(query_key) + if items is None: cursor = database._exec_sql(sql, arguments) if isinstance(translator.expr_type, EntityMeta): entity = translator.expr_type From fe6a9fde7fbe28544db543b8759d68e8be969187 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 2 Aug 2018 14:11:07 +0300 Subject: [PATCH 356/547] Fixes #373: 0.7.4/0.7.5 breaks queries using the `in` operator to test membership of another query result --- pony/orm/ormtypes.py | 2 + pony/orm/sqltranslation.py | 36 +++++++++++++++- .../tests/test_select_from_select_queries.py | 42 +++++++++++++++++++ 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 48ae017f5..c29942350 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -184,6 +184,8 @@ def normalize(value): def normalize_type(t): tt = type(t) if tt is tuple: return tuple(normalize_type(item) for item in t) + if not isinstance(t, type): + return t assert t.__name__ != 'EntityMeta' if tt.__name__ == 'EntityMeta': return t if t is NoneType: return t diff --git a/pony/orm/sqltranslation.py 
b/pony/orm/sqltranslation.py index 7c794ad1a..dc05ba704 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -94,6 +94,13 @@ def dispatch_external(translator, node): if isinstance(t.item_type, EntityMeta): monad = EntityMonad(translator, t.item_type) else: throw(NotImplementedError) # pragma: no cover + elif tt is QueryType: + prev_translator = deepcopy(t.translator) + prev_translator.parent = translator + prev_translator.injected = True + if translator.database is not prev_translator.database: + throw(TranslationError, 'Mixing queries from different databases') + monad = QuerySetMonad(translator, prev_translator) elif tt is FuncType: func = t.func func_monad_class = translator.registered_functions.get(func, ErrorSpecialFuncMonad) @@ -171,6 +178,7 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No ASTTranslator.__init__(translator, tree) translator.can_be_cached = True translator.parent = parent_translator + translator.injected = False if parent_translator is None: translator.root_translator = translator translator.database = None @@ -3132,13 +3140,37 @@ def contains(monad, item, not_in=False): subquery_ast = sub.construct_subquery_ast(distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', [ 'ROW' ] + item_columns, subquery_ast ] else: + ambiguous_names = set() + if sub.injected: + for name in translator.sqlquery.tablerefs: + if name in sub.sqlquery.tablerefs: + ambiguous_names.add(name) subquery_ast = sub.construct_subquery_ast(distinct=False) + if ambiguous_names: + select_ast = subquery_ast[1] + expr_aliases = [] + for i, expr_ast in enumerate(select_ast): + if i > 0: + if expr_ast[0] == 'AS': + expr_ast = expr_ast[1] + expr_alias = 'expr-%d' % i + expr_aliases.append(expr_alias) + expr_ast = [ 'AS', expr_ast, expr_alias ] + select_ast[i] = expr_ast + + new_table_alias = translator.sqlquery.make_alias('t') + new_select_ast = [ 'ALL' ] + for expr_alias in expr_aliases: + 
new_select_ast.append([ 'COLUMN', new_table_alias, expr_alias ]) + new_from_ast = [ 'FROM', [ new_table_alias, 'SELECT', subquery_ast[1:] ] ] + new_where_ast = [ 'WHERE' ] + subquery_ast = [ 'SELECT', new_select_ast, new_from_ast, new_where_ast ] select_ast, from_ast, where_ast = subquery_ast[1:4] in_conditions = [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item_columns, select_ast[1:]) ] - if not sub.aggregated: where_ast += in_conditions - else: + if not ambiguous_names and sub.aggregated: having_ast = find_or_create_having_ast(subquery_ast) having_ast += in_conditions + else: where_ast += in_conditions sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] return BoolExprMonad(translator, sql_ast, nullable=False) def nonzero(monad): diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index 1d63e75c6..b94f2cf58 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -300,6 +300,48 @@ def test_34(self): q2 = q.filter(lambda s: s.scholarship > 450) self.assertEqual(set(q2), {Student[3]}) + @db_session + def test_35(self): + q = select(s for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s not in q) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + + @db_session + def test_36(self): + q = select(s for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s not in q[:]) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_37(self): + q = select(s.last_name for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if s.last_name not in q) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 2) + + @db_session + def test_38(self): + q = select(s.last_name for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if 
s.last_name not in q[:]) + self.assertEqual(set(q2), {1}) + self.assertEqual(db.last_sql.count('SELECT'), 1) + + @db_session + def test_39(self): + q = select((s.first_name, s.last_name) for s in Student if s.scholarship > 0) + q2 = select(s.id for s in Student if (s.first_name, s.last_name) not in q) + self.assertEqual(set(q2), {1}) + self.assertTrue(db.last_sql.count('SELECT') > 1) + + # @db_session + # def test_40(self): # TODO + # q = select((s.first_name, s.last_name) for s in Student if s.scholarship > 0) + # q2 = select(s.id for s in Student if (s.first_name, s.last_name) not in q[:]) + # self.assertEqual(set(q2), {1}) + # self.assertTrue(db.last_sql.count('SELECT'), 1) + if __name__ == '__main__': unittest.main() From a20c9e75a9c0a9040aed395ba1e05c4cb2f00926 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 3 Aug 2018 14:20:55 +0300 Subject: [PATCH 357/547] Fixes #374: auto=True for all types --- pony/orm/core.py | 2 -- pony/orm/dbschema.py | 4 ++-- pony/orm/tests/test_diagram_attribute.py | 5 +++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index d43dbd7b4..fcfa64c1f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1908,8 +1908,6 @@ def __init__(attr, py_type, *args, **kwargs): attr.entity = attr.name = None attr.args = args attr.auto = kwargs.pop('auto', False) - if attr.auto and (attr.py_type not in int_types): throw(TypeError, - '`auto=True` option can be specified for `int` attributes only, not for `%s`' % (attr.py_type.__name__)) attr.cascade_delete = kwargs.pop('cascade_delete', None) attr.reverse = kwargs.pop('reverse', None) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 5e733ed84..0fd971ae8 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import itervalues, basestring +from pony.py23compat import itervalues, basestring, int_types from operator 
import attrgetter @@ -219,7 +219,7 @@ def get_sql(column): result = [] append = result.append append(quote_name(column.name)) - if column.is_pk == 'auto' and column.auto_template: + if column.is_pk == 'auto' and column.auto_template and column.converter.py_type in int_types: append(case(column.auto_template % dict(type=column.sql_type))) else: append(case(column.sql_type)) diff --git a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index 9a8997d03..fccda95ee 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -233,11 +233,12 @@ class Entity2(db.Entity): d = Optional('Entity1', reverse='a') db.generate_mapping() - @raises_exception(TypeError, '`auto=True` option can be specified for `int` attributes only, not for `str`') def test_attribute24(self): db = Database('sqlite', ':memory:') class Entity1(db.Entity): - a = Required(str, auto=True) + a = PrimaryKey(str, auto=True) + db.generate_mapping(create_tables=True) + self.assertTrue('AUTOINCREMENT' not in db.schema.tables['Entity1'].get_create_command()) @raises_exception(TypeError, "Parameters 'column' and 'columns' cannot be specified simultaneously") def test_columns1(self): From 345b5f57ee46c99f91ca61f1a489fd13879aaa38 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 3 Aug 2018 11:27:54 +0300 Subject: [PATCH 358/547] f-strings support --- pony/orm/asttranslation.py | 29 ++++++++++++ pony/orm/decompiling.py | 13 ++++++ pony/orm/sqltranslation.py | 17 +++++++- pony/orm/tests/test-f-strings.py | 75 ++++++++++++++++++++++++++++++++ pony/thirdparty/compiler/ast.py | 41 +++++++++++++++++ 5 files changed, 174 insertions(+), 1 deletion(-) create mode 100644 pony/orm/tests/test-f-strings.py diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 704985076..379afef6c 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -67,6 +67,7 @@ def ast2src(tree): class 
PythonTranslator(ASTTranslator): def __init__(translator, tree): ASTTranslator.__init__(translator, tree) + translator.top_level_f_str = None translator.dispatch(tree) def call(translator, method, node): node.src = method(translator, node) @@ -223,6 +224,34 @@ def postAssName(translator, node): return node.name def postKeyword(translator, node): return '='.join((node.name, node.expr.src)) + def preStr(self, node): + if self.top_level_f_str is None: + self.top_level_f_str = node + def postStr(self, node): + if self.top_level_f_str is node: + self.top_level_f_str = None + return "f%r" % ('{%s}' % node.value.src) + return '{%s}' % node.value.src + def preJoinedStr(self, node): + if self.top_level_f_str is None: + self.top_level_f_str = node + def postJoinedStr(self, node): + result = ''.join( + value.value if isinstance(value, ast.Const) else value.src + for value in node.values) + if self.top_level_f_str is node: + self.top_level_f_str = None + return "f%r" % result + return result + def preFormattedValue(self, node): + if self.top_level_f_str is None: + self.top_level_f_str = node + def postFormattedValue(self, node): + res = '{%s:%s}' % (node.value.src, node.fmt_spec.src) + if self.top_level_f_str is node: + self.top_level_f_str = None + return "f%r" % res + return res nonexternalizable_types = (ast.Keyword, ast.Sliceobj, ast.List, ast.Tuple) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 6a8b3a6f1..c5513d593 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -185,6 +185,10 @@ def BUILD_SLICE(decompiler, size): def BUILD_TUPLE(decompiler, size): return ast.Tuple(decompiler.pop_items(size)) + def BUILD_STRING(decompiler, count): + values = list(reversed([decompiler.stack.pop() for _ in range(count)])) + return ast.JoinedStr(values) + def CALL_FUNCTION(decompiler, argc, star=None, star2=None): pop = decompiler.stack.pop kwarg, posarg = divmod(argc, 256) @@ -261,6 +265,15 @@ def FOR_ITER(decompiler, endpos): ifs = [] return 
ast.GenExprFor(assign, iter, ifs) + def FORMAT_VALUE(decompiler, flags): + if flags in (0, 1, 2, 3): + value = decompiler.stack.pop() + return ast.Str(value, flags) + elif flags == 4: + fmt_spec = decompiler.stack.pop() + value = decompiler.stack.pop() + return ast.FormattedValue(value, fmt_spec) + def GET_ITER(decompiler): pass diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index dc05ba704..88b8756e4 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1045,7 +1045,22 @@ def postIfExp(translator, node): nullable=test_monad.nullable or then_monad.nullable or else_monad.nullable) result.aggregated = test_monad.aggregated or then_monad.aggregated or else_monad.aggregated return result - + def postStr(translator, node): + val_monad = node.value.monad + if isinstance(val_monad, StringMixin): + return val_monad + sql = ['TO_STR', val_monad.getsql()[0] ] + return StringExprMonad(translator, unicode, sql, nullable=val_monad.nullable) + def postJoinedStr(translator, node): + nullable = False + for subnode in node.values: + assert isinstance(subnode.monad, StringMixin), (subnode.monad, subnode) + if subnode.monad.nullable: + nullable = True + sql = [ 'CONCAT' ] + [ value.monad.getsql()[0] for value in node.values ] + return StringExprMonad(translator, unicode, sql, nullable=nullable) + def postFormattedValue(translator, node): + throw(NotImplementedError, 'You cannot set width and precision markers in query') def coerce_monads(m1, m2, for_comparison=False): result_type = coerce_types(m1.type, m2.type) if result_type in numeric_types and bool in (m1.type, m2.type) and ( diff --git a/pony/orm/tests/test-f-strings.py b/pony/orm/tests/test-f-strings.py new file mode 100644 index 000000000..71a7c4950 --- /dev/null +++ b/pony/orm/tests/test-f-strings.py @@ -0,0 +1,75 @@ +import unittest +from pony.orm.core import * +from pony.orm.tests.testutils import * +from sys import version_info + +db = Database('sqlite', ':memory:') + +class 
Person(db.Entity): + first_name = Required(str) + last_name = Required(str) + age = Optional(int) + value = Required(float) + + +db.generate_mapping(create_tables=True) + +with db_session: + Person(id=1, first_name='Alexander', last_name='Tischenko', age=23, value=1.4) + Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2) + Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33) + Person(id=4, first_name='Okita', last_name='Souji', age=15, value=2.1) + Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9) + Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212) + +class TestFString(unittest.TestCase): + def setUp(self): + rollback() + db_session.__enter__() + def tearDown(self): + rollback() + db_session.__exit__() + + if version_info[:2] >= (3, 6): + + def test_1(self): + x = 'Alexander' + y = 'Tischenko' + q = select(p.id for p in Person if p.first_name + ' ' + p.last_name == f'{x} {y}') + self.assertEqual(set(q), {1}) + + def test_2(self): + q = select(p.id for p in Person if f'{p.first_name} {p.last_name}' == 'Alexander Tischenko') + self.assertEqual(set(q), {1}) + + def test_3(self): + x = 'Great' + q = select(f'{p.first_name} the {x}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander the Great'}) + + def test_4(self): + q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander 23'}) + + def test_5(self): + q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander 23'}) + + @raises_exception(NotImplementedError, 'You cannot set width and precision markers in query') + def test_6(self): + width = 3 + precision = 4 + q = select(p.id for p in Person if f'{p.value:{width}.{precision}}')[:] + self.assertEqual({2,}, set(q)) + + def test_7(self): + x = 'Tischenko' + q = select(p.first_name + f"{' ' + x}" for p in Person if p.id == 1) + self.assertEqual(set(q), 
{'Alexander Tischenko'}) + + @sql_debugging + def test_8(self): + q = select(p for p in Person if not p.age).show() + + + diff --git a/pony/thirdparty/compiler/ast.py b/pony/thirdparty/compiler/ast.py index 7268a2573..e5596d4b6 100644 --- a/pony/thirdparty/compiler/ast.py +++ b/pony/thirdparty/compiler/ast.py @@ -530,6 +530,20 @@ def getChildNodes(self): def __repr__(self): return "For(%s, %s, %s, %s)" % (repr(self.assign), repr(self.list), repr(self.body), repr(self.else_)) +class FormattedValue(Node): + def __init__(self, value, fmt_spec): + self.value = value + self.fmt_spec = fmt_spec + + def getChildren(self): + return self.value, self.fmt_spec + + def getChildNodes(self): + return self.value, self.fmt_spec + + def __repr__(self): + return "FormattedValue(%s, %s)" % (self.value, self.fmt_spec) + class From(Node): def __init__(self, modname, names, level, lineno=None): self.modname = modname @@ -1231,6 +1245,33 @@ def getChildNodes(self): def __repr__(self): return "Stmt(%s)" % (repr(self.nodes),) +class Str(Node): + def __init__(self, value, flags): + self.value = value + self.flags = flags + + def getChildren(self): + return self.value, self.flags + + def getChildNodes(self): + return self.value, + + def __repr__(self): + return "Str(%s, %d)" % (self.value, self.flags) + +class JoinedStr(Node): + def __init__(self, values): + self.values = values + + def getChildren(self): + return self.values + + def getChildNodes(self): + return self.values + + def __repr__(self): + return "JoinedStr(%s)" % (', '.join(repr(value) for value in self.values)) + class Sub(Node): def __init__(self, leftright, lineno=None): self.left = leftright[0] From 841b8fd83187bafa48cf8a5ff4ed628acb7e93d8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 3 Aug 2018 16:51:34 +0300 Subject: [PATCH 359/547] Local variable renaming: key -> varkey --- pony/orm/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 
fcfa64c1f..f552f5aa6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5260,7 +5260,7 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): vars = {} vartypes = HashableDict() for src, extractor in iteritems(extractors): - key = filter_num, src + varkey = filter_num, src try: value = extractor(globals, locals) except Exception as cause: raise ExprEvalError(src, cause) @@ -5282,7 +5282,7 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): if src == 'None' and value is not None: throw(TranslationError) if src == 'True' and value is not True: throw(TranslationError) if src == 'False' and value is not False: throw(TranslationError) - try: vartypes[key], value = normalize(value) + try: vartypes[varkey], value = normalize(value) except TypeError: if not isinstance(value, dict): unsupported = False @@ -5294,8 +5294,8 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): if src == '.0': throw(TypeError, 'Query cannot iterate over anything but entity class or another query') throw(TypeError, 'Expression `%s` has unsupported type %r' % (src, typename)) - vartypes[key], value = normalize(value) - vars[key] = value + vartypes[varkey], value = normalize(value) + vars[varkey] = value return vars, vartypes def unpickle_query(query_result): From 797f9e2f5651e1b461426d69251b22cac521b749 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 3 Aug 2018 20:22:42 +0300 Subject: [PATCH 360/547] Fix complex queries: use code_key as part of varkeys --- pony/orm/core.py | 26 +++++++++------- pony/orm/sqltranslation.py | 31 ++++++++++++------- .../tests/test_select_from_select_queries.py | 14 +++++++++ 3 files changed, 47 insertions(+), 24 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index f552f5aa6..20364a2ce 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5249,7 +5249,7 @@ def raw_sql(sql, result_type=None): locals = sys._getframe(1).f_locals return RawSQL(sql, globals, locals, 
result_type) -def extract_vars(filter_num, extractors, globals, locals, cells=None): +def extract_vars(code_key, filter_num, extractors, globals, locals, cells=None): if cells: locals = locals.copy() for name, cell in cells.items(): @@ -5260,7 +5260,7 @@ def extract_vars(filter_num, extractors, globals, locals, cells=None): vars = {} vartypes = HashableDict() for src, extractor in iteritems(extractors): - varkey = filter_num, src + varkey = filter_num, src, code_key try: value = extractor(globals, locals) except Exception as cause: raise ExprEvalError(src, cause) @@ -5306,10 +5306,11 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False assert isinstance(tree, ast.GenExprInner) tree, extractors = create_extractors(code_key, tree, globals, locals, special_functions, const_functions) filter_num = 0 - vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) + vars, vartypes = extract_vars(code_key, filter_num, extractors, globals, locals, cells) node = tree.quals[0].iter - origin = vars[filter_num, node.src] + varkey = filter_num, node.src, code_key + origin = vars[varkey] if isinstance(origin, Query): prev_query = origin elif isinstance(origin, QueryResult): @@ -5331,11 +5332,12 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False if prev_query is not None: database = prev_query._translator.database filter_num = prev_query._filter_num + 1 - vars, vartypes = extract_vars(filter_num, extractors, globals, locals, cells) + vars, vartypes = extract_vars(code_key, filter_num, extractors, globals, locals, cells) query._filter_num = filter_num database.provider.normalize_vars(vars, vartypes) + query._code_key = code_key query._key = HashableDict(code_key=code_key, vartypes=vartypes, left_join=left_join, filters=()) query._database = database @@ -5347,14 +5349,14 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False tree_copy = unpickle_ast(pickled_tree) # tree = 
deepcopy(tree) translator_cls = database.provider.translator_cls try: - translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) + translator = translator_cls(tree_copy, None, code_key, filter_num, extractors, vars, vartypes.copy(), left_join=left_join) except UseAnotherTranslator as e: translator = e.translator name_path = translator.can_be_optimized() if name_path: tree_copy = unpickle_ast(pickled_tree) # tree = deepcopy(tree) try: - translator = translator_cls(tree_copy, None, filter_num, extractors, vars, vartypes.copy(), + translator = translator_cls(tree_copy, None, code_key, filter_num, extractors, vars, vartypes.copy(), left_join=True, optimize=name_path) except UseAnotherTranslator as e: translator = e.translator @@ -5701,19 +5703,19 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names func_ast, extractors = create_extractors( func_id, func_ast, globals, locals, special_functions, const_functions, argnames or prev_translator.namespace) if extractors: - vars, vartypes = extract_vars(new_filter_num, extractors, globals, locals, cells) + vars, vartypes = extract_vars(func_id, new_filter_num, extractors, globals, locals, cells) query._database.provider.normalize_vars(vars, vartypes) new_vars = query._vars.copy() new_vars.update(vars) else: new_vars, vartypes = query._vars, HashableDict() tup = (('order_by' if order_by else 'where' if original_names else 'filter', func_id, vartypes),) new_key = HashableDict(query._key, filters=query._key['filters'] + tup) - new_filters = query._filters + (('apply_lambda', new_filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) + new_filters = query._filters + (('apply_lambda', func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, None, vartypes),) new_translator, new_vars = query._get_translator(new_key, new_vars) if new_translator is None: prev_optimized = 
prev_translator.optimize - new_translator = prev_translator.apply_lambda(new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) + new_translator = prev_translator.apply_lambda(func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) if not prev_optimized: name_path = new_translator.can_be_optimized() if name_path: @@ -5721,13 +5723,13 @@ def _process_lambda(query, func, globals, locals, order_by=False, original_names translator_cls = prev_translator.__class__ try: new_translator = translator_cls( - tree_copy, None, prev_translator.original_filter_num, + tree_copy, None, prev_translator.original_code_key, prev_translator.original_filter_num, prev_translator.extractors, None, prev_translator.vartypes.copy(), left_join=True, optimize=name_path) except UseAnotherTranslator: assert False new_translator = query._reapply_filters(new_translator) - new_translator = new_translator.apply_lambda(new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) + new_translator = new_translator.apply_lambda(func_id, new_filter_num, order_by, func_ast, argnames, original_names, extractors, new_vars, vartypes) query._database._translator_cache[new_key] = new_translator return query._clone(_filter_num=new_filter_num, _vars=new_vars, _key=new_key, _filters=new_filters, _translator=new_translator) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 88b8756e4..ef7b3bbbd 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -85,7 +85,7 @@ def dispatch(translator, node): translator.call(translator.__class__.dispatch_external, node) def dispatch_external(translator, node): - varkey = translator.filter_num, node.src + varkey = translator.filter_num, node.src, translator.code_key t = translator.root_translator.vartypes[varkey] tt = type(t) if t is NoneType: @@ -172,7 +172,7 @@ def call(translator, method, node): else: 
throw(TranslationError, 'Too complex aggregation, expressions cannot be combined: %s' % ast2src(node)) return monad - def __init__(translator, tree, parent_translator, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): + def __init__(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): this = translator assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) @@ -183,14 +183,17 @@ def __init__(translator, tree, parent_translator, filter_num=None, extractors=No translator.root_translator = translator translator.database = None translator.sqlquery = SqlQuery(left_join=left_join) - assert filter_num is not None + assert code_key is not None and filter_num is not None + translator.code_key = translator.original_code_key = code_key translator.filter_num = translator.original_filter_num = filter_num else: translator.root_translator = parent_translator.root_translator translator.database = parent_translator.database translator.sqlquery = SqlQuery(parent_translator.sqlquery, left_join=left_join) + assert code_key is None and filter_num is None + translator.code_key = parent_translator.code_key translator.filter_num = parent_translator.filter_num - translator.original_filter_num = None + translator.original_code_key = translator.original_filter_num = None translator.extractors = extractors translator.vars = vars translator.vartypes = vartypes @@ -255,7 +258,8 @@ def check_name_is_single(): else: assert False # pragma: no cover translator.namespace[name] = ObjectIterMonad(translator, tableref, entity) elif node.external: - iterable = translator.root_translator.vartypes[translator.filter_num, node.src] + varkey = translator.filter_num, node.src, translator.code_key + iterable = translator.root_translator.vartypes[varkey] if isinstance(iterable, SetType): check_name_is_single() entity = iterable.item_type @@ 
-482,6 +486,7 @@ def process_query_qual(translator, prev_translator, names, try_extend_prev_query else: assert translator.parent is None assert prev_translator.vars is None + prev_translator.code_key = translator.code_key prev_translator.filter_num = translator.filter_num prev_translator.extractors.update(translator.extractors) prev_translator.vars = translator.vars @@ -801,9 +806,10 @@ def apply_kwfilters(translator, filterattrs, original_names=False): monads.append(CmpMonad('==', attr_monad, param_monad)) for m in monads: translator.conditions.extend(m.getsql()) return translator - def apply_lambda(translator, filter_num, order_by, func_ast, argnames, original_names, extractors, vars, vartypes): + def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, original_names, extractors, vars, vartypes): translator = deepcopy(translator) func_ast = copy_ast(func_ast) # func_ast = deepcopy(func_ast) + translator.code_key = func_id translator.filter_num = filter_num translator.extractors.update(extractors) translator.vars = vars.copy() if vars is not None else None @@ -1541,7 +1547,6 @@ def __call__(monad, *args, **kwargs): if PY2 and isinstance(func, types.UnboundMethodType): func = func.im_func func_id = id(func) - func_filter_num = translator.filter_num, 'func', id(func) func_ast, external_names, cells = decompile(func) func_ast, func_extractors = create_extractors( @@ -1549,7 +1554,7 @@ def __call__(monad, *args, **kwargs): root_translator = translator.root_translator if func not in root_translator.func_extractors_map: - func_vars, func_vartypes = extract_vars(func_filter_num, func_extractors, func.__globals__, {}, cells) + func_vars, func_vartypes = extract_vars(func_id, translator.filter_num, func_extractors, func.__globals__, {}, cells) translator.database.provider.normalize_vars(func_vars, func_vartypes) if func.__closure__: translator.can_be_cached = False @@ -1561,17 +1566,19 @@ def __call__(monad, *args, **kwargs): stack = 
translator.namespace_stack stack.append(name_mapping) - prev_filter_num = translator.filter_num - translator.filter_num = func_filter_num func_ast = copy_ast(func_ast) try: - translator.dispatch(func_ast) + prev_code_key = translator.code_key + translator.code_key = func_id + try: + translator.dispatch(func_ast) + finally: + translator.code_key = prev_code_key except Exception as e: if len(e.args) == 1 and isinstance(e.args[0], basestring): msg = e.args[0] + ' (inside %s.%s)' % (monad.parent.type.__name__, monad.attrname) e.args = (msg,) raise - translator.filter_num = prev_filter_num stack.pop() return func_ast.monad diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index b94f2cf58..1b3e88997 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -342,6 +342,20 @@ def test_39(self): # self.assertEqual(set(q2), {1}) # self.assertTrue(db.last_sql.count('SELECT'), 1) + @db_session + def test_41(self): + def f1(): + x = 21 + return select(s for s in Student if s.age > x) + + def f2(q): + x = 23 + return select(s.last_name for s in Student if s.age < x and s in q) + + q = f1() + q2 = f2(q) + self.assertEqual(set(q2), {'Lee'}) + if __name__ == '__main__': unittest.main() From d5d913c64a2085dc2c85ca8f5a6a45d0c0e319cf Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 5 Aug 2018 11:44:28 +0300 Subject: [PATCH 361/547] f-strings test fixed --- pony/orm/tests/{test-f-strings.py => test_f_strings.py} | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) rename pony/orm/tests/{test-f-strings.py => test_f_strings.py} (96%) diff --git a/pony/orm/tests/test-f-strings.py b/pony/orm/tests/test_f_strings.py similarity index 96% rename from pony/orm/tests/test-f-strings.py rename to pony/orm/tests/test_f_strings.py index 71a7c4950..67c801475 100644 --- a/pony/orm/tests/test-f-strings.py +++ b/pony/orm/tests/test_f_strings.py @@ -67,9 
+67,5 @@ def test_7(self): q = select(p.first_name + f"{' ' + x}" for p in Person if p.id == 1) self.assertEqual(set(q), {'Alexander Tischenko'}) - @sql_debugging def test_8(self): - q = select(p for p in Person if not p.age).show() - - - + q = select(p for p in Person if not p.age)[:] From b4f87b368e6fe6ab5fc5f8b248d3eff63e3607f1 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 5 Aug 2018 14:43:23 +0300 Subject: [PATCH 362/547] Fix test f strings for Python version < 3.6 --- pony/orm/tests/py36_test_f_strings.py | 68 +++++++++++++++++++++++++ pony/orm/tests/test_f_strings.py | 71 +-------------------------- 2 files changed, 70 insertions(+), 69 deletions(-) create mode 100644 pony/orm/tests/py36_test_f_strings.py diff --git a/pony/orm/tests/py36_test_f_strings.py b/pony/orm/tests/py36_test_f_strings.py new file mode 100644 index 000000000..6c291ba57 --- /dev/null +++ b/pony/orm/tests/py36_test_f_strings.py @@ -0,0 +1,68 @@ +import unittest +from pony.orm.core import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:') + +class Person(db.Entity): + first_name = Required(str) + last_name = Required(str) + age = Optional(int) + value = Required(float) + + +db.generate_mapping(create_tables=True) + +with db_session: + Person(id=1, first_name='Alexander', last_name='Tischenko', age=23, value=1.4) + Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2) + Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33) + Person(id=4, first_name='Okita', last_name='Souji', age=15, value=2.1) + Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9) + Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212) + +class TestFString(unittest.TestCase): + def setUp(self): + rollback() + db_session.__enter__() + def tearDown(self): + rollback() + db_session.__exit__() + + def test_1(self): + x = 'Alexander' + y = 'Tischenko' + q = select(p.id for p in Person if 
p.first_name + ' ' + p.last_name == f'{x} {y}') + self.assertEqual(set(q), {1}) + + def test_2(self): + q = select(p.id for p in Person if f'{p.first_name} {p.last_name}' == 'Alexander Tischenko') + self.assertEqual(set(q), {1}) + + def test_3(self): + x = 'Great' + q = select(f'{p.first_name} the {x}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander the Great'}) + + def test_4(self): + q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander 23'}) + + def test_5(self): + q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander 23'}) + + @raises_exception(NotImplementedError, 'You cannot set width and precision markers in query') + def test_6(self): + width = 3 + precision = 4 + q = select(p.id for p in Person if f'{p.value:{width}.{precision}}')[:] + self.assertEqual({2,}, set(q)) + + def test_7(self): + x = 'Tischenko' + q = select(p.first_name + f"{' ' + x}" for p in Person if p.id == 1) + self.assertEqual(set(q), {'Alexander Tischenko'}) + + def test_8(self): + q = select(p for p in Person if not p.age)[:] diff --git a/pony/orm/tests/test_f_strings.py b/pony/orm/tests/test_f_strings.py index 67c801475..fa6414e49 100644 --- a/pony/orm/tests/test_f_strings.py +++ b/pony/orm/tests/test_f_strings.py @@ -1,71 +1,4 @@ -import unittest -from pony.orm.core import * -from pony.orm.tests.testutils import * from sys import version_info -db = Database('sqlite', ':memory:') - -class Person(db.Entity): - first_name = Required(str) - last_name = Required(str) - age = Optional(int) - value = Required(float) - - -db.generate_mapping(create_tables=True) - -with db_session: - Person(id=1, first_name='Alexander', last_name='Tischenko', age=23, value=1.4) - Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2) - Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33) - Person(id=4, first_name='Okita', last_name='Souji', 
age=15, value=2.1) - Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9) - Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212) - -class TestFString(unittest.TestCase): - def setUp(self): - rollback() - db_session.__enter__() - def tearDown(self): - rollback() - db_session.__exit__() - - if version_info[:2] >= (3, 6): - - def test_1(self): - x = 'Alexander' - y = 'Tischenko' - q = select(p.id for p in Person if p.first_name + ' ' + p.last_name == f'{x} {y}') - self.assertEqual(set(q), {1}) - - def test_2(self): - q = select(p.id for p in Person if f'{p.first_name} {p.last_name}' == 'Alexander Tischenko') - self.assertEqual(set(q), {1}) - - def test_3(self): - x = 'Great' - q = select(f'{p.first_name} the {x}' for p in Person if p.id == 1) - self.assertEqual(set(q), {'Alexander the Great'}) - - def test_4(self): - q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) - self.assertEqual(set(q), {'Alexander 23'}) - - def test_5(self): - q = select(f'{p.first_name} {p.age}' for p in Person if p.id == 1) - self.assertEqual(set(q), {'Alexander 23'}) - - @raises_exception(NotImplementedError, 'You cannot set width and precision markers in query') - def test_6(self): - width = 3 - precision = 4 - q = select(p.id for p in Person if f'{p.value:{width}.{precision}}')[:] - self.assertEqual({2,}, set(q)) - - def test_7(self): - x = 'Tischenko' - q = select(p.first_name + f"{' ' + x}" for p in Person if p.id == 1) - self.assertEqual(set(q), {'Alexander Tischenko'}) - - def test_8(self): - q = select(p for p in Person if not p.age)[:] +if version_info[:2] >= (3, 6): + from pony.orm.tests.py36_test_f_strings import * \ No newline at end of file From 1683ac0fb0053c38b902ef1cc43848b26d7067ac Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 5 Aug 2018 15:51:36 +0300 Subject: [PATCH 363/547] Fix query optimization --- pony/orm/sqltranslation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff 
--git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index ef7b3bbbd..63ba0c1eb 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -452,8 +452,8 @@ def can_be_optimized(translator): if translator.groupby_monads: return False if len(translator.aggregated_subquery_paths) != 1: return False aggr_path = next(iter(translator.aggregated_subquery_paths)) - for name in translator.sqlquery.tablerefs: - if not aggr_path.startswith(name): + for tableref in translator.sqlquery.tablerefs.values(): + if tableref.joined and not aggr_path.startswith(tableref.name_path): return False return aggr_path def process_query_qual(translator, prev_translator, names, try_extend_prev_query=False): @@ -1258,7 +1258,7 @@ def make_join(tableref, pk_only=False): tableref.alias = parent_alias tableref.pk_columns = left_columns tableref.optimized = True - tableref.joined = True + # tableref.joined = True return parent_alias, left_columns alias = sqlquery.make_alias(tableref.var_name or entity.__name__) join_cond = join_tables(parent_alias, alias, left_columns, pk_columns) From e1dfdafebfd90dba3ce2e062537fef208812e1f8 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 7 Aug 2018 12:31:21 +0300 Subject: [PATCH 364/547] Now exists() in query does not throw away generator expression --- pony/orm/sqltranslation.py | 13 ++++++- pony/orm/tests/test_exists.py | 72 +++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 pony/orm/tests/test_exists.py diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 63ba0c1eb..e0255a0d2 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1329,7 +1329,8 @@ def mixin_init(monad): def cmp(monad, op, monad2): return CmpMonad(op, monad, monad2) def contains(monad, item, not_in=False): throw(TypeError) - def nonzero(monad): throw(TypeError) + def nonzero(monad): + return CmpMonad('is not', monad, NoneMonad(monad.translator)) def 
negate(monad): return NotMonad(monad) def getattr(monad, attrname): @@ -2255,6 +2256,8 @@ class DatetimeConstMonad(DatetimeMixin, ConstMonad): pass class BoolMonad(Monad): def __init__(monad, translator, nullable=True): Monad.__init__(monad, translator, bool, nullable=nullable) + def nonzero(monad): + return monad sql_negation = { 'IN' : 'NOT_IN', 'EXISTS' : 'NOT_EXISTS', 'LIKE' : 'NOT_LIKE', 'BETWEEN' : 'NOT_BETWEEN', 'IS_NULL' : 'IS_NOT_NULL' } sql_negation.update((value, key) for key, value in items_list(sql_negation)) @@ -3197,6 +3200,14 @@ def contains(monad, item, not_in=False): return BoolExprMonad(translator, sql_ast, nullable=False) def nonzero(monad): subquery_ast = monad.subtranslator.construct_subquery_ast(distinct=False) + expr_monads = monad.subtranslator.expr_monads + if len(expr_monads) > 1: + throw(NotImplementedError) + expr_monad = expr_monads[0] + if not isinstance(expr_monad, ObjectIterMonad): + sql = expr_monad.nonzero().getsql() + assert subquery_ast[3][0] == 'WHERE' + subquery_ast[3].append(sql[0]) subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] translator = monad.translator return BoolExprMonad(translator, subquery_ast, nullable=False) diff --git a/pony/orm/tests/test_exists.py b/pony/orm/tests/test_exists.py new file mode 100644 index 000000000..7227db88d --- /dev/null +++ b/pony/orm/tests/test_exists.py @@ -0,0 +1,72 @@ +import unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:') + +class Group(db.Entity): + students = Set('Student') + +class Student(db.Entity): + first_name = Required(str) + last_name = Required(str) + login = Optional(str, nullable=True) + graduated = Optional(bool, default=False) + group = Required(Group) + passport = Optional('Passport', column='passport') + +class Passport(db.Entity): + student = Optional(Student) + +db.generate_mapping(create_tables=True) + +with db_session: + g1 = Group() + g2 = Group() + + p = Passport() + + 
Student(first_name='Mashu', last_name='Kyrielight', login='Shielder', group=g1) + Student(first_name='Okita', last_name='Souji', login='Sakura', group=g1) + Student(first_name='Francis', last_name='Drake', group=g2, graduated=True) + Student(first_name='Oda', last_name='Nobunaga', group=g2, graduated=True) + Student(first_name='William', last_name='Shakespeare', group=g2, graduated=True, passport=p) + +class TestExists(unittest.TestCase): + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_1(self): + q = select(g for g in Group if exists(s.login for s in g.students))[:] + self.assertEqual(q[0], Group[1]) + + def test_2(self): + q = select(g for g in Group if exists(s.graduated for s in g.students))[:] + self.assertEqual(q[0], Group[2]) + + def test_3(self): + q = select(s for s in Student if + exists(len(s2.first_name) == len(s.first_name) and s != s2 for s2 in Student))[:] + self.assertEqual(set(q), {Student[1], Student[2], Student[3], Student[5]}) + + def test_4(self): + q = select(g for g in Group if not exists(not s.graduated for s in g.students))[:] + self.assertEqual(q[0], Group[2]) + + def test_5(self): + q = select(g for g in Group if exists(s for s in g.students))[:] + self.assertEqual(set(q), {Group[1], Group[2]}) + + def test_6(self): + q = select(g for g in Group if exists(s.login for s in g.students if s.first_name != 'Okita') and g.id != 10)[:] + self.assertEqual(q[0], Group[1]) + + def test_7(self): + q = select(g for g in Group if exists(s.passport for s in g.students))[:] + self.assertEqual(q[0], Group[2]) \ No newline at end of file From ffa446bc9ccf238e434ad69e8cb843a852607103 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 7 Aug 2018 13:20:23 +0300 Subject: [PATCH 365/547] Get rid of ponytest and click --- pony/orm/tests/test_inner_join_syntax.py | 2 - pony/orm/tests/test_json.py | 105 +++++++++++------------ 2 files changed, 48 insertions(+), 59 
deletions(-) diff --git a/pony/orm/tests/test_inner_join_syntax.py b/pony/orm/tests/test_inner_join_syntax.py index 20d4977db..c70238aed 100644 --- a/pony/orm/tests/test_inner_join_syntax.py +++ b/pony/orm/tests/test_inner_join_syntax.py @@ -3,8 +3,6 @@ from pony.orm import * from pony import orm -import pony.orm.tests.fixtures - class TestJoin(unittest.TestCase): exclude_fixtures = {'test': ['clear_tables']} diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index d223cf60f..7e4dc42fd 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -2,75 +2,66 @@ import unittest -import click - from pony.orm import * from pony.orm.tests.testutils import raises_exception, raises_if from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict -from contextlib import contextmanager - -import pony.orm.tests.fixtures -from ponytest import with_cli_args, TestCase - -class TestJson(TestCase): +class TestJson(unittest.TestCase): - @classmethod - def make_entities(cls): - class Product(cls.db.Entity): + def setUp(self): + self.db = Database('sqlite', ':memory:') + class Product(self.db.Entity): name = Required(str) info = Optional(Json) tags = Optional(Json) - cls.Product = cls.db.Product - - - @db_session - def setUp(self): - self.db.execute('delete from %s' % self.db.Product._table_) - - self.Product( - name='Apple iPad Air 2', - info={ - 'name': 'Apple iPad Air 2', - 'display': { - 'size': 9.7, - 'resolution': [2048, 1536], - 'matrix-type': 'IPS', - 'multi-touch': True - }, - 'os': { - 'type': 'iOS', - 'version': '8' + self.db.generate_mapping(create_tables=True) + + self.Product = Product + + with db_session: + self.Product( + name='Apple iPad Air 2', + info={ + 'name': 'Apple iPad Air 2', + 'display': { + 'size': 9.7, + 'resolution': [2048, 1536], + 'matrix-type': 'IPS', + 'multi-touch': True + }, + 'os': { + 'type': 'iOS', + 'version': '8' + }, + 'cpu': 'Apple A8X', + 'ram': '8GB', + 'colors': ['Gold', 'Silver', 'Space 
Gray'], + 'models': [ + { + 'name': 'Wi-Fi', + 'capacity': ['16GB', '64GB'], + 'height': 240, + 'width': 169.5, + 'depth': 6.1, + 'weight': 437, + }, + { + 'name': 'Wi-Fi + Cellular', + 'capacity': ['16GB', '64GB'], + 'height': 240, + 'width': 169.5, + 'depth': 6.1, + 'weight': 444, + }, + ], + 'discontinued': False, + 'videoUrl': None, + 'non-ascii-attr': u'\u0442\u0435\u0441\u0442' }, - 'cpu': 'Apple A8X', - 'ram': '8GB', - 'colors': ['Gold', 'Silver', 'Space Gray'], - 'models': [ - { - 'name': 'Wi-Fi', - 'capacity': ['16GB', '64GB'], - 'height': 240, - 'width': 169.5, - 'depth': 6.1, - 'weight': 437, - }, - { - 'name': 'Wi-Fi + Cellular', - 'capacity': ['16GB', '64GB'], - 'height': 240, - 'width': 169.5, - 'depth': 6.1, - 'weight': 444, - }, - ], - 'discontinued': False, - 'videoUrl': None, - 'non-ascii-attr': u'\u0442\u0435\u0441\u0442' - }, - tags=['Tablets', 'Apple', 'Retina']) + tags=['Tablets', 'Apple', 'Retina']) def test(self): From 6be926639db0225b53449296c6e1fc03057e5d9a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 5 Aug 2018 19:23:52 +0300 Subject: [PATCH 366/547] Translator context added --- pony/orm/dbproviders/oracle.py | 8 +- pony/orm/dbproviders/sqlite.py | 2 +- pony/orm/sqltranslation.py | 490 ++++++++++++++++----------------- 3 files changed, 245 insertions(+), 255 deletions(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index ed07806d3..2928c58c1 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -106,15 +106,15 @@ class OraSchema(DBSchema): column_class = OraColumn class OraNoneMonad(sqltranslation.NoneMonad): - def __init__(monad, translator, value=None): + def __init__(monad, value=None): assert value in (None, '') - sqltranslation.ConstMonad.__init__(monad, translator, None) + sqltranslation.ConstMonad.__init__(monad, None) class OraConstMonad(sqltranslation.ConstMonad): @staticmethod - def new(translator, value): + def new(value): if value == '': value = 
None - return sqltranslation.ConstMonad.new(translator, value) + return sqltranslation.ConstMonad.new(value) class OraTranslator(sqltranslation.SQLTranslator): dialect = 'Oracle' diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index d8d20e82f..ff3196b04 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -40,7 +40,7 @@ def func(translator, monad): sql = monad.getsql() assert len(sql) == 1 translator = monad.translator - return StringExprMonad(translator, monad.type, [ sqlop, sql[0] ]) + return StringExprMonad(monad.type, [ sqlop, sql[0] ]) func.__name__ = sqlop return func diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e0255a0d2..ac674ac7d 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -12,7 +12,7 @@ from pony.thirdparty.compiler import ast from pony import options, utils -from pony.utils import is_ident, throw, reraise, copy_ast, between, concat, coalesce +from pony.utils import localbase, is_ident, throw, reraise, copy_ast, between, concat, coalesce from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError, create_extractors from pony.orm.decompiling import decompile from pony.orm.ormtypes import \ @@ -68,6 +68,16 @@ def type2str(t): try: return t.__name__ except: return str(t) +class Local(localbase): + def __init__(local): + local.translators = [] + + @property + def translator(self): + return local.translators[-1] + +local = Local() + class SQLTranslator(ASTTranslator): dialect = None row_value_syntax = True @@ -75,6 +85,16 @@ class SQLTranslator(ASTTranslator): json_values_are_comparable = True rowid_support = False + def __enter__(translator): + local.translators.append(translator) + + def __exit__(translator, exc_type, exc_val, exc_tb): + t = local.translators.pop() + if isinstance(exc_val, UseAnotherTranslator): + assert t is exc_val.translator + else: + assert t is translator + def default_post(translator, node): 
throw(NotImplementedError) # pragma: no cover @@ -89,10 +109,10 @@ def dispatch_external(translator, node): t = translator.root_translator.vartypes[varkey] tt = type(t) if t is NoneType: - monad = ConstMonad.new(translator, None) + monad = ConstMonad.new(None) elif tt is SetType: if isinstance(t.item_type, EntityMeta): - monad = EntityMonad(translator, t.item_type) + monad = EntityMonad(t.item_type) else: throw(NotImplementedError) # pragma: no cover elif tt is QueryType: prev_translator = deepcopy(t.translator) @@ -100,35 +120,35 @@ def dispatch_external(translator, node): prev_translator.injected = True if translator.database is not prev_translator.database: throw(TranslationError, 'Mixing queries from different databases') - monad = QuerySetMonad(translator, prev_translator) + monad = QuerySetMonad(prev_translator) elif tt is FuncType: func = t.func func_monad_class = translator.registered_functions.get(func, ErrorSpecialFuncMonad) - monad = func_monad_class(translator, func) + monad = func_monad_class(func) elif tt is MethodType: obj, func = t.obj, t.func if isinstance(obj, EntityMeta): - entity_monad = EntityMonad(translator, obj) + entity_monad = EntityMonad(obj) if obj.__class__.__dict__.get(func.__name__) is not func: throw(NotImplementedError) monad = MethodMonad(entity_monad, func.__name__) elif node.src == 'random': # For PyPy - monad = FuncRandomMonad(translator, t) + monad = FuncRandomMonad(t) else: throw(NotImplementedError) elif isinstance(node, ast.Name) and node.name in ('True', 'False'): value = True if node.name == 'True' else False - monad = ConstMonad.new(translator, value) + monad = ConstMonad.new(value) elif tt is tuple: params = [] for i, item_type in enumerate(t): if item_type is NoneType: throw(TypeError, 'Expression `%s` should not contain None values' % node.src) - param = ParamMonad.new(translator, item_type, (varkey, i, None)) + param = ParamMonad.new(item_type, (varkey, i, None)) params.append(param) - monad = ListMonad(translator, 
params) + monad = ListMonad(params) elif isinstance(t, RawSQLType): - monad = RawSQLMonad(translator, t, varkey) + monad = RawSQLMonad(t, varkey) else: - monad = ParamMonad.new(translator, t, (varkey, None, None)) + monad = ParamMonad.new(t, (varkey, None, None)) node.monad = monad monad.node = node monad.aggregated = monad.nogroup = False @@ -173,6 +193,20 @@ def call(translator, method, node): return monad def __init__(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): + local.translators.append(translator) + try: + translator.init(tree, parent_translator, code_key, filter_num, extractors, vars, vartypes, left_join, optimize) + except UseAnotherTranslator as e: + assert local.translators + t = local.translators.pop() + assert t is e.translator + raise + else: + assert local.translators + t = local.translators.pop() + assert t is translator + + def init(translator, tree, parent_translator, code_key=None, filter_num=None, extractors=None, vars=None, vartypes=None, left_join=False, optimize=None): this = translator assert isinstance(tree, ast.GenExprInner), tree ASTTranslator.__init__(translator, tree) @@ -256,7 +290,7 @@ def check_name_is_single(): translator.sqlquery = monad._subselect(translator.sqlquery, extract_outer_conditions=False) tableref = monad.tableref else: assert False # pragma: no cover - translator.namespace[name] = ObjectIterMonad(translator, tableref, entity) + translator.namespace[name] = ObjectIterMonad(tableref, entity) elif node.external: varkey = translator.filter_num, node.src, translator.code_key iterable = translator.root_translator.vartypes[varkey] @@ -273,14 +307,16 @@ def check_name_is_single(): tableref = TableRef(translator.sqlquery, name, entity) translator.sqlquery.tablerefs[name] = tableref tableref.make_join() - translator.namespace[name] = node.monad = ObjectIterMonad(translator, tableref, entity) + translator.namespace[name] = 
node.monad = ObjectIterMonad(tableref, entity) elif isinstance(iterable, QueryType): prev_translator = deepcopy(iterable.translator) database = prev_translator.database try: translator.process_query_qual(prev_translator, names, try_extend_prev_query=not i) except UseAnotherTranslator as e: + assert local.translators and local.translators[-1] is translator translator = e.translator + local.translators[-1] = translator else: throw(TranslationError, 'Inside declarative query, iterator must be entity. ' 'Got: for %s in %s' % (name, ast2src(qual.iter))) @@ -331,7 +367,7 @@ def check_name_is_single(): else: name_path += '-' + attr.name tableref = translator.sqlquery.add_tableref(name_path, parent_tableref, attr) if j == last_index: - translator.namespace[name] = ObjectIterMonad(translator, tableref, tableref.entity) + translator.namespace[name] = ObjectIterMonad(tableref, tableref.entity) if can_affect_distinct is not None: tableref.can_affect_distinct = can_affect_distinct parent_tableref = tableref @@ -508,7 +544,7 @@ def process_query_qual(translator, prev_translator, names, try_extend_prev_query tableref = StarTableRef(sqlquery, name, entity, subquery_ast) tablerefs[name] = tableref tableref.make_join() - translator.namespace[name] = ObjectIterMonad(translator, tableref, entity) + translator.namespace[name] = ObjectIterMonad(tableref, entity) else: aliases = [] aliases_dict = {} @@ -536,11 +572,11 @@ def process_query_qual(translator, prev_translator, names, try_extend_prev_query if isinstance(t, EntityMeta): columns = aliases_dict[base_expr_monad] expr_tableref = ExprJoinedTableRef(sqlquery, tableref, columns, name, t) - expr_monad = ObjectIterMonad(translator, expr_tableref, t) + expr_monad = ObjectIterMonad(expr_tableref, t) else: column = aliases_dict[base_expr_monad] expr_ast = ['COLUMN', tableref.alias, column] - expr_monad = ExprMonad.new(translator, t, expr_ast, base_expr_monad.nullable) + expr_monad = ExprMonad.new(t, expr_ast, base_expr_monad.nullable) 
assert name not in translator.namespace translator.namespace[name] = expr_monad def construct_subquery_ast(translator, aliases=None, star=None, distinct=None, is_not_null_checks=False): @@ -788,24 +824,25 @@ def order_by_attributes(translator, attrs): return translator def apply_kwfilters(translator, filterattrs, original_names=False): translator = deepcopy(translator) - if original_names: - object_monad = translator.tree.quals[0].iter.monad - assert isinstance(object_monad.type, EntityMeta) - else: - object_monad = translator.tree.expr.monad - if not isinstance(object_monad.type, EntityMeta): - throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') - - monads = [] - none_monad = NoneMonad(translator) - for attr, id, is_none in filterattrs: - attr_monad = object_monad.getattr(attr.name) - if is_none: monads.append(CmpMonad('is', attr_monad, none_monad)) + with translator: + if original_names: + object_monad = translator.tree.quals[0].iter.monad + assert isinstance(object_monad.type, EntityMeta) else: - param_monad = ParamMonad.new(translator, attr.py_type, (id, None, None)) - monads.append(CmpMonad('==', attr_monad, param_monad)) - for m in monads: translator.conditions.extend(m.getsql()) - return translator + object_monad = translator.tree.expr.monad + if not isinstance(object_monad.type, EntityMeta): + throw(TypeError, 'Keyword arguments are not allowed when query result is not entity objects') + + monads = [] + none_monad = NoneMonad() + for attr, id, is_none in filterattrs: + attr_monad = object_monad.getattr(attr.name) + if is_none: monads.append(CmpMonad('is', attr_monad, none_monad)) + else: + param_monad = ParamMonad.new(attr.py_type, (id, None, None)) + monads.append(CmpMonad('==', attr_monad, param_monad)) + for m in monads: translator.conditions.extend(m.getsql()) + return translator def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, original_names, extractors, vars, vartypes): 
translator = deepcopy(translator) func_ast = copy_ast(func_ast) # func_ast = deepcopy(func_ast) @@ -826,36 +863,37 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, if namespace is not None: translator.namespace_stack.append(namespace) - try: - translator.dispatch(func_ast) - if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes - else: nodes = (func_ast,) - if order_by: - translator.inside_order_by = True - new_order = [] - for node in nodes: - if isinstance(node.monad, SetMixin): - t = node.monad.type.item_type - if isinstance(type(t), type): t = t.__name__ - throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' - % (t, ast2src(node))) - new_order.extend(node.monad.getsql()) - translator.order[:0] = new_order - translator.inside_order_by = False - else: - for node in nodes: - monad = node.monad - if isinstance(monad, AndMonad): cond_monads = monad.operands - else: cond_monads = [ monad ] - for m in cond_monads: - if not m.aggregated: translator.conditions.extend(m.getsql()) - else: translator.having_conditions.extend(m.getsql()) - translator.vars = None - return translator - finally: - if namespace is not None: - ns = translator.namespace_stack.pop() - assert ns is namespace + with translator: + try: + translator.dispatch(func_ast) + if isinstance(func_ast, ast.Tuple): nodes = func_ast.nodes + else: nodes = (func_ast,) + if order_by: + translator.inside_order_by = True + new_order = [] + for node in nodes: + if isinstance(node.monad, SetMixin): + t = node.monad.type.item_type + if isinstance(type(t), type): t = t.__name__ + throw(TranslationError, 'Set of %s (%s) cannot be used for ordering' + % (t, ast2src(node))) + new_order.extend(node.monad.getsql()) + translator.order[:0] = new_order + translator.inside_order_by = False + else: + for node in nodes: + monad = node.monad + if isinstance(monad, AndMonad): cond_monads = monad.operands + else: cond_monads = [ monad ] + for m in cond_monads: + if not 
m.aggregated: translator.conditions.extend(m.getsql()) + else: translator.having_conditions.extend(m.getsql()) + translator.vars = None + return translator + finally: + if namespace is not None: + ns = translator.namespace_stack.pop() + assert ns is namespace def preGenExpr(translator, node): inner_tree = node.code translator_cls = translator.__class__ @@ -863,7 +901,7 @@ def preGenExpr(translator, node): subtranslator = translator_cls(inner_tree, translator) except UseAnotherTranslator: assert False - return QuerySetMonad(translator, subtranslator) + return QuerySetMonad(subtranslator) def postGenExprIf(translator, node): monad = node.test.monad if monad.type is not bool: monad = monad.nonzero() @@ -894,15 +932,15 @@ def postConst(translator, node): if type(value) is frozenset: value = tuple(sorted(value)) if type(value) is not tuple: - return ConstMonad.new(translator, value) + return ConstMonad.new(value) else: - return ListMonad(translator, [ ConstMonad.new(translator, item) for item in value ]) + return ListMonad([ ConstMonad.new(item) for item in value ]) def postEllipsis(translator, node): - return ConstMonad.new(translator, Ellipsis) + return ConstMonad.new(Ellipsis) def postList(translator, node): - return ListMonad(translator, [ item.monad for item in node.nodes ]) + return ListMonad([ item.monad for item in node.nodes ]) def postTuple(translator, node): - return ListMonad(translator, [ item.monad for item in node.nodes ]) + return ListMonad([ item.monad for item in node.nodes ]) def postName(translator, node): monad = translator.resolve_name(node.name) assert monad is not None @@ -990,7 +1028,7 @@ def preCallFunc(translator, node): subtranslator = translator_cls(inner_expr, translator) except UseAnotherTranslator: assert False - monad = QuerySetMonad(translator, subtranslator) + monad = QuerySetMonad(subtranslator) if method_name == 'exists': monad = monad.nonzero() return monad @@ -1013,7 +1051,7 @@ def postSubscript(translator, node): if len(node.subs) 
> 1: for x in node.subs: if isinstance(x, ast.Sliceobj): throw(TypeError) - key = ListMonad(translator, [ item.monad for item in node.subs ]) + key = ListMonad([ item.monad for item in node.subs ]) return node.expr.monad[key] sub = node.subs[0] if isinstance(sub, ast.Sliceobj): @@ -1047,7 +1085,7 @@ def postIfExp(translator, node): elif not translator.row_value_syntax: throw(NotImplementedError) else: then_sql, else_sql = [ 'ROW' ] + then_sql, [ 'ROW' ] + else_sql expr = [ 'CASE', None, [ [ test_sql, then_sql ] ], else_sql ] - result = ExprMonad.new(translator, result_type, expr, + result = ExprMonad.new(result_type, expr, nullable=test_monad.nullable or then_monad.nullable or else_monad.nullable) result.aggregated = test_monad.aggregated or then_monad.aggregated or else_monad.aggregated return result @@ -1056,7 +1094,7 @@ def postStr(translator, node): if isinstance(val_monad, StringMixin): return val_monad sql = ['TO_STR', val_monad.getsql()[0] ] - return StringExprMonad(translator, unicode, sql, nullable=val_monad.nullable) + return StringExprMonad(unicode, sql, nullable=val_monad.nullable) def postJoinedStr(translator, node): nullable = False for subnode in node.values: @@ -1064,7 +1102,7 @@ def postJoinedStr(translator, node): if subnode.monad.nullable: nullable = True sql = [ 'CONCAT' ] + [ value.monad.getsql()[0] for value in node.values ] - return StringExprMonad(translator, unicode, sql, nullable=nullable) + return StringExprMonad(unicode, sql, nullable=nullable) def postFormattedValue(translator, node): throw(NotImplementedError, 'You cannot set width and precision markers in query') def coerce_monads(m1, m2, for_comparison=False): @@ -1076,11 +1114,11 @@ def coerce_monads(m1, m2, for_comparison=False): if result_type is bool: result_type = int if m1.type is bool: - new_m1 = NumericExprMonad(translator, int, [ 'TO_INT', m1.getsql()[0] ], nullable=m1.nullable) + new_m1 = NumericExprMonad(int, [ 'TO_INT', m1.getsql()[0] ], nullable=m1.nullable) 
new_m1.aggregated = m1.aggregated m1 = new_m1 if m2.type is bool: - new_m2 = NumericExprMonad(translator, int, [ 'TO_INT', m2.getsql()[0] ], nullable=m2.nullable) + new_m2 = NumericExprMonad(int, [ 'TO_INT', m2.getsql()[0] ], nullable=m2.nullable) new_m2.aggregated = m2.aggregated m2 = new_m2 return result_type, m1, m2 @@ -1318,9 +1356,9 @@ class MonadMixin(with_metaclass(MonadMeta)): class Monad(with_metaclass(MonadMeta)): disable_distinct = False disable_ordering = False - def __init__(monad, translator, type, nullable=True): + def __init__(monad, type, nullable=True): monad.node = None - monad.translator = translator + monad.translator = local.translator monad.type = type monad.nullable = nullable monad.mixin_init() @@ -1330,7 +1368,7 @@ def cmp(monad, op, monad2): return CmpMonad(op, monad, monad2) def contains(monad, item, not_in=False): throw(TypeError) def nonzero(monad): - return CmpMonad('is not', monad, NoneMonad(monad.translator)) + return CmpMonad('is not', monad, NoneMonad()) def negate(monad): return NotMonad(monad) def getattr(monad, attrname): @@ -1366,7 +1404,7 @@ def count(monad, distinct=None): '%s database provider does not support entities ' 'with composite primary keys inside aggregate functions. 
Got: {EXPR}' % translator.dialect) - result = ExprMonad.new(translator, int, [ 'COUNT', distinct, expr ], nullable=False) + result = ExprMonad.new(int, [ 'COUNT', distinct, expr ], nullable=False) result.aggregated = True return result def aggregate(monad, func_name, distinct=None, sep=None): @@ -1408,7 +1446,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): if func_name == 'GROUP_CONCAT': if sep is not None: aggr_ast.append(['VALUE', sep]) - result = ExprMonad.new(translator, result_type, aggr_ast, nullable=True) + result = ExprMonad.new(result_type, aggr_ast, nullable=True) result.aggregated = True return result def __call__(monad, *args, **kwargs): throw(TypeError) @@ -1426,9 +1464,9 @@ def __xor__(monad): throw(TypeError) def abs(monad): throw(TypeError) def cast_from_json(monad, type): assert False, monad def to_int(monad): - return NumericExprMonad(monad.translator, int, [ 'TO_INT', monad.getsql()[0] ], nullable=monad.nullable) + return NumericExprMonad(int, [ 'TO_INT', monad.getsql()[0] ], nullable=monad.nullable) def to_real(monad): - return NumericExprMonad(monad.translator, float, [ 'TO_REAL', monad.getsql()[0] ], nullable=monad.nullable) + return NumericExprMonad(float, [ 'TO_REAL', monad.getsql()[0] ], nullable=monad.nullable) def distinct_from_monad(distinct, default=None): if distinct is None: @@ -1438,10 +1476,10 @@ def distinct_from_monad(distinct, default=None): throw(TypeError, '`distinct` value should be True or False. 
Got: %s' % ast2src(distinct.node)) class RawSQLMonad(Monad): - def __init__(monad, translator, rawtype, varkey, nullable=True): + def __init__(monad, rawtype, varkey, nullable=True): if rawtype.result_type is None: type = rawtype else: type = normalize_type(rawtype.result_type) - Monad.__init__(monad, translator, type, nullable=nullable) + Monad.__init__(monad, type, nullable=nullable) monad.rawtype = rawtype monad.varkey = varkey def contains(monad, item, not_in=False): @@ -1453,7 +1491,7 @@ def contains(monad, item, not_in=False): '%s database provider does not support tuples. Got: {EXPR} ' % translator.dialect) op = 'NOT_IN' if not_in else 'IN' sql = [ op, expr, monad.getsql() ] - return BoolExprMonad(translator, sql, nullable=item.nullable) + return BoolExprMonad(sql, nullable=item.nullable) def nonzero(monad): return monad def getsql(monad, sqlquery=None): provider = monad.translator.database.provider @@ -1510,7 +1548,7 @@ def raise_forgot_parentheses(monad): class MethodMonad(Monad): def __init__(monad, parent, attrname): - Monad.__init__(monad, parent.translator, 'METHOD', nullable=False) + Monad.__init__(monad, 'METHOD', nullable=False) monad.parent = parent monad.attrname = attrname def getattr(monad, attrname): @@ -1584,8 +1622,9 @@ def __call__(monad, *args, **kwargs): return func_ast.monad class EntityMonad(Monad): - def __init__(monad, translator, entity): - Monad.__init__(monad, translator, SetType(entity)) + def __init__(monad, entity): + Monad.__init__(monad, SetType(entity)) + translator = monad.translator if translator.database is None: translator.database = entity._database_ elif translator.database is not entity._database_: @@ -1594,11 +1633,10 @@ def __getitem__(monad, *args): throw(NotImplementedError) class ListMonad(Monad): - def __init__(monad, translator, items): - Monad.__init__(monad, translator, tuple(item.type for item in items)) + def __init__(monad, items): + Monad.__init__(monad, tuple(item.type for item in items)) monad.items = 
items def contains(monad, x, not_in=False): - translator = monad.translator if isinstance(x.type, SetType): throw(TypeError, "Type of `%s` is '%s'. Expression `{EXPR}` is not supported" % (ast2src(x.node), type2str(x.type))) for item in monad.items: check_comparable(x, item) @@ -1610,7 +1648,7 @@ def contains(monad, x, not_in=False): sql = sqland([ sqlor([ [ 'NE', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) else: sql = sqlor([ sqland([ [ 'EQ', a, b ] for a, b in izip(left_sql, item.getsql()) ]) for item in monad.items ]) - return BoolExprMonad(translator, sql, nullable=x.nullable or any(item.nullable for item in monad.items)) + return BoolExprMonad(sql, nullable=x.nullable or any(item.nullable for item in monad.items)) def getsql(monad, sqlquery=None): return [ [ 'ROW' ] + [ item.getsql()[0] for item in monad.items ] ] @@ -1624,7 +1662,6 @@ class UuidMixin(MonadMixin): def make_numeric_binop(op, sqlop): def numeric_binop(monad, monad2): - translator = monad.translator if isinstance(monad2, (AttrSetMonad, NumericSetExprMonad)): return NumericSetExprMonad(op, sqlop, monad, monad2) if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) @@ -1633,7 +1670,7 @@ def numeric_binop(monad, monad2): throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql()[0] right_sql = monad2.getsql()[0] - return NumericExprMonad(translator, result_type, [ sqlop, left_sql, right_sql ]) + return NumericExprMonad(result_type, [ sqlop, left_sql, right_sql ]) numeric_binop.__name__ = sqlop return numeric_binop @@ -1647,28 +1684,25 @@ def mixin_init(monad): __floordiv__ = make_numeric_binop('//', 'FLOORDIV') __mod__ = make_numeric_binop('%', 'MOD') def __pow__(monad, monad2): - translator = monad.translator if not isinstance(monad2, NumericMixin): throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), '**')) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) 
== len(right_sql) == 1 - return NumericExprMonad(translator, float, [ 'POW', left_sql[0], right_sql[0] ], + return NumericExprMonad(float, [ 'POW', left_sql[0], right_sql[0] ], nullable=monad.nullable or monad2.nullable) def __neg__(monad): sql = monad.getsql()[0] - translator = monad.translator - return NumericExprMonad(translator, monad.type, [ 'NEG', sql ], nullable=monad.nullable) + return NumericExprMonad(monad.type, [ 'NEG', sql ], nullable=monad.nullable) def abs(monad): sql = monad.getsql()[0] - translator = monad.translator - return NumericExprMonad(translator, monad.type, [ 'ABS', sql ], nullable=monad.nullable) + return NumericExprMonad(monad.type, [ 'ABS', sql ], nullable=monad.nullable) def nonzero(monad): translator = monad.translator sql = monad.getsql()[0] if not (translator.dialect == 'PostgreSQL' and monad.type is bool): sql = [ 'NE', sql, [ 'VALUE', 0 ] ] - return BoolExprMonad(translator, sql, nullable=False) + return BoolExprMonad(sql, nullable=False) def negate(monad): sql = monad.getsql()[0] translator = monad.translator @@ -1681,24 +1715,22 @@ def negate(monad): result_sql = [ 'NOT', [ 'COALESCE', sql, [ 'VALUE', True ] ] ] else: result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', 0 ] ], [ 'VALUE', 0 ] ] - return BoolExprMonad(translator, result_sql, nullable=False) + return BoolExprMonad(result_sql, nullable=False) def numeric_attr_factory(name): def attr_func(monad): sql = [ name, monad.getsql()[0] ] - translator = monad.translator - return NumericExprMonad(translator, int, sql, nullable=monad.nullable) + return NumericExprMonad(int, sql, nullable=monad.nullable) attr_func.__name__ = name.lower() return attr_func def make_datetime_binop(op, sqlop): def datetime_binop(monad, monad2): - translator = monad.translator if monad2.type != timedelta: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) expr_monad_cls = DateExprMonad if monad.type is date else DatetimeExprMonad delta = monad2.value if 
isinstance(monad2, TimedeltaConstMonad) else monad2.getsql()[0] - return expr_monad_cls(translator, monad.type, [ sqlop, monad.getsql()[0], delta ], + return expr_monad_cls(monad.type, [ sqlop, monad.getsql()[0], delta ], nullable=monad.nullable or monad2.nullable) datetime_binop.__name__ = sqlop return datetime_binop @@ -1727,9 +1759,8 @@ class DatetimeMixin(DateMixin): def mixin_init(monad): assert monad.type is datetime def call_date(monad): - translator = monad.translator sql = [ 'DATE', monad.getsql()[0] ] - return ExprMonad.new(translator, date, sql, nullable=monad.nullable) + return ExprMonad.new(date, sql, nullable=monad.nullable) attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') @@ -1738,14 +1769,13 @@ def call_date(monad): def make_string_binop(op, sqlop): def string_binop(monad, monad2): - translator = monad.translator if not are_comparable_types(monad.type, monad2.type, sqlop): if monad2.type == 'METHOD': raise_forgot_parentheses(monad2) throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) left_sql = monad.getsql() right_sql = monad2.getsql() assert len(left_sql) == len(right_sql) == 1 - return StringExprMonad(translator, monad.type, [ sqlop, left_sql[0], right_sql[0] ], + return StringExprMonad(monad.type, [ sqlop, left_sql[0], right_sql[0] ], nullable=monad.nullable or monad2.nullable) string_binop.__name__ = sqlop return string_binop @@ -1754,8 +1784,7 @@ def make_string_func(sqlop): def func(monad): sql = monad.getsql() assert len(sql) == 1 - translator = monad.translator - return StringExprMonad(translator, monad.type, [ sqlop, sql[0] ], nullable=monad.nullable) + return StringExprMonad(monad.type, [ sqlop, sql[0] ], nullable=monad.nullable) func.__name__ = sqlop return func @@ -1764,7 +1793,6 @@ def mixin_init(monad): assert issubclass(monad.type, basestring), monad.type __add__ = make_string_binop('+', 'CONCAT') def __getitem__(monad, 
index): - translator = monad.translator if isinstance(index, ListMonad): throw(TypeError, "String index must be of 'int' type. Got 'tuple' in {EXPR}") elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') @@ -1775,7 +1803,7 @@ def __getitem__(monad, index): and (stop is None or isinstance(stop, NumericConstMonad)): if start is not None: start = start.value if stop is not None: stop = stop.value - return ConstMonad.new(translator, monad.value[start:stop]) + return ConstMonad.new(monad.value[start:stop]) if start is not None and start.type is not int: throw(TypeError, "Invalid type of start index (expected 'int', got %r) in string slice {EXPR}" % type2str(start.type)) @@ -1783,7 +1811,7 @@ def __getitem__(monad, index): throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in string slice {EXPR}" % type2str(stop.type)) expr_sql = monad.getsql()[0] - if start is None: start = ConstMonad.new(translator, 0) + if start is None: start = ConstMonad.new(0) if isinstance(start, NumericConstMonad): if start.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') @@ -1808,11 +1836,11 @@ def __getitem__(monad, index): len_sql = [ 'SUB', stop_sql, start.getsql()[0] ] sql = [ 'SUBSTR', expr_sql, start_sql, len_sql ] - return StringExprMonad(translator, monad.type, sql, + return StringExprMonad(monad.type, sql, nullable=monad.nullable or start.nullable or stop is not None and stop.nullable) if isinstance(monad, StringConstMonad) and isinstance(index, NumericConstMonad): - return ConstMonad.new(translator, monad.value[index.value]) + return ConstMonad.new(monad.value[index.value]) if index.type is not int: throw(TypeError, 'String indices must be integers. 
Got %r in expression {EXPR}' % type2str(index.type)) expr_sql = monad.getsql()[0] @@ -1824,7 +1852,7 @@ def __getitem__(monad, index): inner_sql = index.getsql()[0] index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] - return StringExprMonad(translator, monad.type, sql, nullable=monad.nullable) + return StringExprMonad(monad.type, sql, nullable=monad.nullable) def negate(monad): sql = monad.getsql()[0] translator = monad.translator @@ -1837,7 +1865,7 @@ def negate(monad): result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] else: result_sql = [ 'EQ', [ 'COALESCE', sql, [ 'VALUE', '' ] ], [ 'VALUE', '' ]] - result = BoolExprMonad(translator, result_sql, nullable=False) + result = BoolExprMonad(result_sql, nullable=False) result.aggregated = monad.aggregated return result def nonzero(monad): @@ -1847,13 +1875,12 @@ def nonzero(monad): result_sql = [ 'IS_NOT_NULL', sql ] else: result_sql = [ 'NE', sql, [ 'VALUE', '' ] ] - result = BoolExprMonad(translator, result_sql, nullable=False) + result = BoolExprMonad(result_sql, nullable=False) result.aggregated = monad.aggregated return result def len(monad): sql = monad.getsql()[0] - translator = monad.translator - return NumericExprMonad(translator, int, [ 'LENGTH', sql ]) + return NumericExprMonad(int, [ 'LENGTH', sql ]) def contains(monad, item, not_in=False): check_comparable(item, monad, 'LIKE') return monad._like(item, before='%', after='%', not_like=not_in) @@ -1898,9 +1925,8 @@ def _like(monad, item, before=None, after=None, not_like=False): if escape: result_sql.append([ 'VALUE', '!' 
]) if not_like and monad.nullable and (isinstance(monad, AttrMonad) or translator.dialect == 'Oracle'): result_sql = [ 'OR', result_sql, [ 'IS_NULL', sql ] ] - return BoolExprMonad(translator, result_sql, nullable=not_like) + return BoolExprMonad(result_sql, nullable=not_like) def strip(monad, chars, strip_type): - translator = monad.translator if chars is not None and not are_comparable_types(monad.type, chars.type, None): if chars.type == 'METHOD': raise_forgot_parentheses(chars) throw(TypeError, "'chars' argument must be of %r type in {EXPR}, got: %r" @@ -1908,7 +1934,7 @@ def strip(monad, chars, strip_type): parent_sql = monad.getsql()[0] sql = [ strip_type, parent_sql ] if chars is not None: sql.append(chars.getsql()[0]) - return StringExprMonad(translator, monad.type, sql, nullable=monad.nullable) + return StringExprMonad(monad.type, sql, nullable=monad.nullable) def call_strip(monad, chars=None): return monad.strip(chars, 'TRIM') def call_lstrip(monad, chars=None): @@ -1938,37 +1964,31 @@ def contains(monad, key, not_in=False): key_sql = key.getsql()[0] sql = [ 'JSON_CONTAINS', base_sql, path, key_sql ] if not_in: sql = [ 'NOT', sql ] - return BoolExprMonad(translator, sql) + return BoolExprMonad(sql) def __or__(monad, other): - translator = monad.translator if not isinstance(other, JsonMixin): raise TypeError('Should be JSON: %s' % ast2src(other.node)) left_sql = monad.getsql()[0] right_sql = other.getsql()[0] sql = [ 'JSON_CONCAT', left_sql, right_sql ] - return JsonExprMonad(translator, Json, sql) + return JsonExprMonad(Json, sql) def len(monad): - translator = monad.translator sql = [ 'JSON_ARRAY_LENGTH', monad.getsql()[0] ] - return NumericExprMonad(translator, int, sql) + return NumericExprMonad(int, sql) def cast_from_json(monad, type): if type in (Json, NoneType): return monad throw(TypeError, 'Cannot compare whole JSON value, you need to select specific sub-item: {EXPR}') def nonzero(monad): - translator = monad.translator - return 
BoolExprMonad(translator, [ 'JSON_NONZERO', monad.getsql()[0] ]) + return BoolExprMonad([ 'JSON_NONZERO', monad.getsql()[0] ]) class ObjectMixin(MonadMixin): def mixin_init(monad): assert isinstance(monad.type, EntityMeta) def negate(monad): - translator = monad.translator - return CmpMonad('is', monad, NoneMonad(translator)) + return CmpMonad('is', monad, NoneMonad()) def nonzero(monad): - translator = monad.translator - return CmpMonad('is not', monad, NoneMonad(translator)) + return CmpMonad('is not', monad, NoneMonad()) def getattr(monad, attrname): - translator = monad.translator entity = monad.type attr = entity._adict_.get(attrname) or entity._subclass_adict_.get(attrname) if attr is None: @@ -1991,8 +2011,8 @@ def requires_distinct(monad, joined=False): return monad.attr.reverse.is_collection or monad.parent.requires_distinct(joined) # parent ??? class ObjectIterMonad(ObjectMixin, Monad): - def __init__(monad, translator, tableref, entity): - Monad.__init__(monad, translator, entity) + def __init__(monad, tableref, entity): + Monad.__init__(monad, entity) monad.tableref = tableref def getsql(monad, sqlquery=None): entity = monad.type @@ -2004,7 +2024,6 @@ def requires_distinct(monad, joined=False): class AttrMonad(Monad): @staticmethod def new(parent, attr, *args, **kwargs): - translator = parent.translator type = normalize_type(attr.py_type) if type in numeric_types: cls = NumericAttrMonad elif type is unicode: cls = StringAttrMonad @@ -2023,9 +2042,8 @@ def __new__(cls, *args): return Monad.__new__(cls) def __init__(monad, parent, attr): assert monad.__class__ is not AttrMonad - translator = parent.translator attr_type = normalize_type(attr.py_type) - Monad.__init__(monad, parent.translator, attr_type) + Monad.__init__(monad, attr_type) monad.parent = parent monad.attr = attr monad.nullable = attr.nullable @@ -2073,7 +2091,7 @@ class JsonAttrMonad(JsonMixin, AttrMonad): pass class ParamMonad(Monad): @staticmethod - def new(translator, type, paramkey): + 
def new(type, paramkey): type = normalize_type(type) if type in numeric_types: cls = NumericParamMonad elif type is unicode: cls = StringParamMonad @@ -2086,27 +2104,28 @@ def new(translator, type, paramkey): elif type is Json: cls = JsonParamMonad elif isinstance(type, EntityMeta): cls = ObjectParamMonad else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type,)) - result = cls(translator, type, paramkey) + result = cls(type, paramkey) result.aggregated = False return result def __new__(cls, *args): if cls is ParamMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, type, paramkey): + def __init__(monad, type, paramkey): type = normalize_type(type) - Monad.__init__(monad, translator, type, nullable=False) + Monad.__init__(monad, type, nullable=False) monad.paramkey = paramkey if not isinstance(type, EntityMeta): - provider = translator.database.provider + provider = monad.translator.database.provider monad.converter = provider.get_converter_by_py_type(type) else: monad.converter = None def getsql(monad, sqlquery=None): return [ [ 'PARAM', monad.paramkey, monad.converter ] ] class ObjectParamMonad(ObjectMixin, ParamMonad): - def __init__(monad, translator, entity, paramkey): - assert translator.database is entity._database_ - ParamMonad.__init__(monad, translator, entity, paramkey) + def __init__(monad, entity, paramkey): + ParamMonad.__init__(monad, entity, paramkey) + if monad.translator.database is not entity._database_: + assert monad.translator.database is entity._database_, (paramkey, monad.translator.database, entity._database_) varkey, i, j = paramkey assert j is None monad.params = tuple((varkey, i, j) for j in xrange(len(entity._pk_converters_))) @@ -2132,7 +2151,7 @@ def getsql(monad, sqlquery=None): class ExprMonad(Monad): @staticmethod - def new(translator, type, sql, nullable=True): + def new(type, sql, nullable=True): if type in numeric_types: cls = 
NumericExprMonad elif type is unicode: cls = StringExprMonad elif type is date: cls = DateExprMonad @@ -2142,12 +2161,12 @@ def new(translator, type, sql, nullable=True): elif type is Json: cls = JsonExprMonad elif isinstance(type, EntityMeta): cls = ObjectExprMonad else: throw(NotImplementedError, type) # pragma: no cover - return cls(translator, type, sql, nullable=nullable) + return cls(type, sql, nullable=nullable) def __new__(cls, *args, **kwargs): if cls is ExprMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, type, sql, nullable=True): - Monad.__init__(monad, translator, type, nullable=nullable) + def __init__(monad, type, sql, nullable=True): + Monad.__init__(monad, type, nullable=nullable) monad.sql = sql def getsql(monad, sqlquery=None): return [ monad.sql ] @@ -2167,8 +2186,7 @@ class JsonExprMonad(JsonMixin, ExprMonad): pass class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): assert isinstance(parent, JsonMixin), parent - translator = parent.translator - Monad.__init__(monad, translator, Json) + Monad.__init__(monad, Json) monad.parent = parent if isinstance(key, slice): if key != slice(None, None, None): throw(NotImplementedError) @@ -2176,6 +2194,7 @@ def __init__(monad, parent, key): elif isinstance(key, (ParamMonad, StringConstMonad, NumericConstMonad, EllipsisMonad)): monad.key_ast = key.getsql()[0] else: throw(TypeError, 'Invalid JSON path item: %s' % ast2src(key.node)) + translator = monad.translator if isinstance(key, (slice, EllipsisMonad)) and not translator.json_path_wildcard_syntax: throw(TranslationError, '%s does not support wildcards in JSON path: {EXPR}' % translator.dialect) def get_path(monad): @@ -2195,7 +2214,7 @@ def cast_from_json(monad, type): return monad base_monad, path = monad.get_path() sql = [ 'JSON_VALUE', base_monad.getsql()[0], path, type ] - return ExprMonad.new(translator, Json if type is NoneType else type, sql) + return 
ExprMonad.new(Json if type is NoneType else type, sql) def getsql(monad): base_monad, path = monad.get_path() base_sql = base_monad.getsql()[0] @@ -2206,7 +2225,7 @@ def getsql(monad): class ConstMonad(Monad): @staticmethod - def new(translator, value): + def new(value): value_type, value = normalize(value) if value_type in numeric_types: cls = NumericConstMonad elif value_type is unicode: cls = StringConstMonad @@ -2219,31 +2238,31 @@ def new(translator, value): elif value_type is Json: cls = JsonConstMonad elif issubclass(value_type, type(Ellipsis)): cls = EllipsisMonad else: throw(NotImplementedError, value_type) # pragma: no cover - result = cls(translator, value) + result = cls(value) result.aggregated = False return result def __new__(cls, *args): if cls is ConstMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, translator, value): + def __init__(monad, value): value_type, value = normalize(value) - Monad.__init__(monad, translator, value_type, nullable=value_type is NoneType) + Monad.__init__(monad, value_type, nullable=value_type is NoneType) monad.value = value def getsql(monad, sqlquery=None): return [ [ 'VALUE', monad.value ] ] class NoneMonad(ConstMonad): type = NoneType - def __init__(monad, translator, value=None): + def __init__(monad, value=None): assert value is None - ConstMonad.__init__(monad, translator, value) + ConstMonad.__init__(monad, value) class EllipsisMonad(ConstMonad): pass class StringConstMonad(StringMixin, ConstMonad): def len(monad): - return ConstMonad.new(monad.translator, len(monad.value)) + return ConstMonad.new(len(monad.value)) class JsonConstMonad(JsonMixin, ConstMonad): pass class BufferConstMonad(BufferMixin, ConstMonad): pass @@ -2254,8 +2273,8 @@ class TimedeltaConstMonad(TimedeltaMixin, ConstMonad): pass class DatetimeConstMonad(DatetimeMixin, ConstMonad): pass class BoolMonad(Monad): - def __init__(monad, translator, nullable=True): - Monad.__init__(monad, 
translator, bool, nullable=nullable) + def __init__(monad, nullable=True): + Monad.__init__(monad, bool, nullable=nullable) def nonzero(monad): return monad @@ -2263,13 +2282,12 @@ def nonzero(monad): sql_negation.update((value, key) for key, value in items_list(sql_negation)) class BoolExprMonad(BoolMonad): - def __init__(monad, translator, sql, nullable=True): - BoolMonad.__init__(monad, translator, nullable=nullable) + def __init__(monad, sql, nullable=True): + BoolMonad.__init__(monad, nullable=nullable) monad.sql = sql def getsql(monad, sqlquery=None): return [ monad.sql ] def negate(monad): - translator = monad.translator sql = monad.sql sqlop = sql[0] negated_op = sql_negation.get(sqlop) @@ -2279,7 +2297,7 @@ def negate(monad): assert len(sql) == 2 negated_sql = sql[1] else: return NotMonad(monad) - return BoolExprMonad(translator, negated_sql, nullable=monad.nullable) + return BoolExprMonad(negated_sql, nullable=monad.nullable) cmp_ops = { '>=' : 'GE', '>' : 'GT', '<=' : 'LE', '<' : 'LT' } @@ -2290,7 +2308,6 @@ class CmpMonad(BoolMonad): EQ = 'EQ' NE = 'NE' def __init__(monad, op, left, right): - translator = left.translator if op == '<>': op = '!=' if left.type is NoneType: assert right.type is not NoneType @@ -2302,7 +2319,7 @@ def __init__(monad, op, left, right): elif op == 'is not': op = '!=' check_comparable(left, right, op) result_type, left, right = coerce_monads(left, right, for_comparison=True) - BoolMonad.__init__(monad, translator, nullable=left.nullable or right.nullable) + BoolMonad.__init__(monad, nullable=left.nullable or right.nullable) monad.op = op monad.aggregated = getattr(left, 'aggregated', False) or getattr(right, 'aggregated', False) @@ -2349,15 +2366,13 @@ class LogicalBinOpMonad(BoolMonad): def __init__(monad, operands): assert len(operands) >= 2 items = [] - translator = operands[0].translator - monad.translator = translator for operand in operands: if operand.type is not bool: items.append(operand.nonzero()) elif 
isinstance(operand, LogicalBinOpMonad) and monad.binop == operand.binop: items.extend(operand.operands) else: items.append(operand) nullable = any(item.nullable for item in items) - BoolMonad.__init__(monad, items[0].translator, nullable=nullable) + BoolMonad.__init__(monad, nullable=nullable) monad.operands = items def getsql(monad, sqlquery=None): result = [ monad.binop ] @@ -2376,7 +2391,7 @@ class OrMonad(LogicalBinOpMonad): class NotMonad(BoolMonad): def __init__(monad, operand): if operand.type is not bool: operand = operand.nonzero() - BoolMonad.__init__(monad, operand.translator, nullable=operand.nullable) + BoolMonad.__init__(monad, nullable=operand.nullable) monad.operand = operand def negate(monad): return monad.operand @@ -2384,8 +2399,8 @@ def getsql(monad, sqlquery=None): return [ [ 'NOT', monad.operand.getsql()[0] ] ] class ErrorSpecialFuncMonad(Monad): - def __init__(monad, translator, func): - Monad.__init__(monad, translator, func) + def __init__(monad, func): + Monad.__init__(monad, func) monad.func = func registered_functions = SQLTranslator.registered_functions = {} @@ -2402,7 +2417,6 @@ def __new__(meta, cls_name, bases, cls_dict): class FuncMonad(with_metaclass(FuncMonadMeta, Monad)): def __call__(monad, *args, **kwargs): - translator = monad.translator for arg in args: assert isinstance(arg, Monad) for value in kwargs.values(): @@ -2414,7 +2428,6 @@ def __call__(monad, *args, **kwargs): class FuncBufferMonad(FuncMonad): func = buffer def call(monad, source, encoding=None, errors=None): - translator = monad.translator if not isinstance(source, StringConstMonad): throw(TypeError) source = source.value if encoding is not None: @@ -2426,12 +2439,12 @@ def call(monad, source, encoding=None, errors=None): if PY2: if encoding and errors: source = source.encode(encoding, errors) elif encoding: source = source.encode(encoding) - return ConstMonad.new(translator, buffer(source)) + return ConstMonad.new(buffer(source)) else: if encoding and errors: 
value = buffer(source, encoding, errors) elif encoding: value = buffer(source, encoding) else: value = buffer(source) - return ConstMonad.new(translator, value) + return ConstMonad.new(value) class FuncBoolMonad(FuncMonad): func = bool @@ -2451,38 +2464,33 @@ def call(monad, x): class FuncDecimalMonad(FuncMonad): func = Decimal def call(monad, x): - translator = monad.translator if not isinstance(x, StringConstMonad): throw(TypeError) - return ConstMonad.new(translator, Decimal(x.value)) + return ConstMonad.new(Decimal(x.value)) class FuncDateMonad(FuncMonad): func = date def call(monad, year, month, day): - translator = monad.translator for arg, name in izip((year, month, day), ('year', 'month', 'day')): if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of date(year, month, day) function must be of 'int' type. " "Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) - return ConstMonad.new(translator, date(year.value, month.value, day.value)) + return ConstMonad.new(date(year.value, month.value, day.value)) def call_today(monad): - translator = monad.translator - return DateExprMonad(translator, date, [ 'TODAY' ], nullable=monad.nullable) + return DateExprMonad(date, [ 'TODAY' ], nullable=monad.nullable) class FuncTimeMonad(FuncMonad): func = time def call(monad, *args): - translator = monad.translator for arg, name in izip(args, ('hour', 'minute', 'second', 'microsecond')): if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of time(...) function must be of 'int' type. 
Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) - return ConstMonad.new(translator, time(*tuple(arg.value for arg in args))) + return ConstMonad.new(time(*tuple(arg.value for arg in args))) class FuncTimedeltaMonad(FuncMonad): func = timedelta def call(monad, days=None, seconds=None, microseconds=None, milliseconds=None, minutes=None, hours=None, weeks=None): - translator = monad.translator args = days, seconds, microseconds, milliseconds, minutes, hours, weeks for arg, name in izip(args, ('days', 'seconds', 'microseconds', 'milliseconds', 'minutes', 'hours', 'weeks')): if arg is None: continue @@ -2490,23 +2498,21 @@ def call(monad, days=None, seconds=None, microseconds=None, milliseconds=None, m "'%s' argument of timedelta(...) function must be of 'int' type. Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = timedelta(*(arg.value if arg is not None else 0 for arg in args)) - return ConstMonad.new(translator, value) + return ConstMonad.new(value) class FuncDatetimeMonad(FuncDateMonad): func = datetime def call(monad, year, month, day, hour=None, minute=None, second=None, microsecond=None): args = year, month, day, hour, minute, second, microsecond - translator = monad.translator for arg, name in izip(args, ('year', 'month', 'day', 'hour', 'minute', 'second', 'microsecond')): if arg is None: continue if not isinstance(arg, NumericMixin) or arg.type is not int: throw(TypeError, "'%s' argument of datetime(...) function must be of 'int' type. 
Got: %r" % (name, type2str(arg.type))) if not isinstance(arg, ConstMonad): throw(NotImplementedError) value = datetime(*(arg.value if arg is not None else 0 for arg in args)) - return ConstMonad.new(translator, value) + return ConstMonad.new(value) def call_now(monad): - translator = monad.translator - return DatetimeExprMonad(translator, datetime, [ 'NOW' ], nullable=monad.nullable) + return DatetimeExprMonad(datetime, [ 'NOW' ], nullable=monad.nullable) class FuncBetweenMonad(FuncMonad): func = between @@ -2515,22 +2521,20 @@ def call(monad, x, a, b): check_comparable(x, b, '<') if isinstance(x.type, EntityMeta): throw(TypeError, '%s instance cannot be argument of between() function: {EXPR}' % x.type.__name__) - translator = x.translator sql = [ 'BETWEEN', x.getsql()[0], a.getsql()[0], b.getsql()[0] ] - return BoolExprMonad(translator, sql, nullable=x.nullable or a.nullable or b.nullable) + return BoolExprMonad(sql, nullable=x.nullable or a.nullable or b.nullable) class FuncConcatMonad(FuncMonad): func = concat def call(monad, *args): if len(args) < 2: throw(TranslationError, 'concat() function requires at least two arguments') - translator = args[0].translator result_ast = [ 'CONCAT' ] for arg in args: t = arg.type if isinstance(t, EntityMeta) or type(t) in (tuple, SetType): throw(TranslationError, 'Invalid argument of concat() function: %s' % ast2src(arg.node)) result_ast.extend(arg.getsql()) - return ExprMonad.new(translator, unicode, result_ast, nullable=any(arg.nullable for arg in args)) + return ExprMonad.new(unicode, result_ast, nullable=any(arg.nullable for arg in args)) class FuncLenMonad(FuncMonad): func = len @@ -2559,10 +2563,9 @@ def call(monad, obj_monad, name_monad): class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count def call(monad, x=None, distinct=None): - translator = monad.translator if isinstance(x, StringConstMonad) and x.value == '*': x = None if x is not None: return x.count(distinct) - result = 
ExprMonad.new(translator, int, [ 'COUNT', None ], nullable=False) + result = ExprMonad.new(int, [ 'COUNT', None ], nullable=False) result.aggregated = True return result @@ -2594,7 +2597,6 @@ class FuncCoalesceMonad(FuncMonad): func = coalesce def call(monad, *args): if len(args) < 2: throw(TranslationError, 'coalesce() function requires at least two arguments') - translator = args[0].translator arg = args[0] t = arg.type result = [ [ sql ] for sql in arg.getsql() ] @@ -2604,7 +2606,7 @@ def call(monad, *args): result[i].append(sql) sql = [ [ 'COALESCE' ] + coalesce_args for coalesce_args in result ] if not isinstance(t, EntityMeta): sql = sql[0] - return ExprMonad.new(translator, t, sql, nullable=all(arg.nullable for arg in args)) + return ExprMonad.new(t, sql, nullable=all(arg.nullable for arg in args)) class FuncDistinctMonad(FuncMonad): func = utils.distinct, core.distinct @@ -2648,14 +2650,13 @@ def minmax(monad, sqlop, *args): args = list(args) for i, arg in enumerate(args): if arg.type is bool: - args[i] = NumericExprMonad(translator, int, [ 'TO_INT', arg.getsql() ], nullable=arg.nullable) + args[i] = NumericExprMonad(int, [ 'TO_INT', arg.getsql() ], nullable=arg.nullable) sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ] - return ExprMonad.new(translator, t, sql, nullable=any(arg.nullable for arg in args)) + return ExprMonad.new(t, sql, nullable=any(arg.nullable for arg in args)) class FuncSelectMonad(FuncMonad): func = core.select def call(monad, queryset): - translator = monad.translator if not isinstance(queryset, QuerySetMonad): throw(TypeError, "'select' function expects generator expression, got: {EXPR}") return queryset @@ -2674,14 +2675,15 @@ def call(monad, expr): class DescMonad(Monad): def __init__(monad, expr): - Monad.__init__(monad, expr.translator, expr.type, nullable=expr.nullable) + Monad.__init__(monad, expr.type, nullable=expr.nullable) monad.expr = expr def getsql(monad): return [ [ 'DESC', item ] for item in 
monad.expr.getsql() ] class JoinMonad(Monad): - def __init__(monad, translator, type): - Monad.__init__(monad, translator, type) + def __init__(monad, type): + Monad.__init__(monad, type) + translator = monad.translator monad.hint_join_prev = translator.hint_join translator.hint_join = True def __call__(monad, x): @@ -2691,11 +2693,11 @@ def __call__(monad, x): class FuncRandomMonad(FuncMonad): func = random - def __init__(monad, translator, type): - FuncMonad.__init__(monad, translator, type) - translator.query_result_is_cacheable = False + def __init__(monad, type): + FuncMonad.__init__(monad, type) + monad.translator.query_result_is_cacheable = False def __call__(monad): - return NumericExprMonad(monad.translator, float, [ 'RANDOM' ], nullable=False) + return NumericExprMonad(float, [ 'RANDOM' ], nullable=False) class SetMixin(MonadMixin): forced_distinct = False @@ -2712,15 +2714,13 @@ def attrset_binop(monad, monad2): class AttrSetMonad(SetMixin, Monad): def __init__(monad, parent, attr): - translator = parent.translator item_type = normalize_type(attr.py_type) - Monad.__init__(monad, translator, SetType(item_type)) + Monad.__init__(monad, SetType(item_type)) monad.parent = parent monad.attr = attr monad.sqlquery = None monad.tableref = None def cmp(monad, op, monad2): - translator = monad.translator if type(monad2.type) is SetType \ and are_comparable_types(monad.type.item_type, monad2.type.item_type): pass elif monad.type != monad2.type: check_comparable(monad, monad2) @@ -2743,7 +2743,7 @@ def contains(monad, item, not_in=False): else: conditions += [ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(item.getsql(), expr_list) ] sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS', from_ast, [ 'WHERE' ] + conditions ] - result = BoolExprMonad(translator, sql_ast, nullable=False) + result = BoolExprMonad(sql_ast, nullable=False) result.nogroup = True return result elif not not_in: @@ -2751,7 +2751,7 @@ def contains(monad, item, not_in=False): tableref = 
monad.make_tableref(translator.sqlquery) expr_list = monad.make_expr_list() expr_ast = sqland([ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(expr_list, item.getsql()) ]) - return BoolExprMonad(translator, expr_ast, nullable=False) + return BoolExprMonad(expr_ast, nullable=False) else: sqlquery = SqlQuery(translator.sqlquery) tableref = monad.make_tableref(sqlquery) @@ -2766,7 +2766,7 @@ def contains(monad, item, not_in=False): conditions.extend(sqlquery.conditions) from_ast[-1][-1] = sqland([ from_ast[-1][-1] ] + conditions) expr_ast = sqland([ [ 'IS_NULL', expr ] for expr in expr_list ]) - return BoolExprMonad(translator, expr_ast, nullable=False) + return BoolExprMonad(expr_ast, nullable=False) def getattr(monad, name): try: return Monad.getattr(monad, name) except AttributeError: pass @@ -2842,7 +2842,7 @@ def count(monad, distinct=None): else: sql_ast, optimized = monad._aggregated_scalar_subselect(make_aggr, extra_grouping) translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = ExprMonad.new(translator, int, sql_ast, nullable=False) + result = ExprMonad.new(int, sql_ast, nullable=False) if optimized: result.aggregated = True else: result.nogroup = True return result @@ -2886,7 +2886,7 @@ def make_aggr(expr_list): else: result_type = item_type translator.aggregated_subquery_paths.add(monad.tableref.name_path) - result = ExprMonad.new(monad.translator, result_type, sql_ast, nullable=func_name != 'SUM') + result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') if optimized: result.aggregated = True else: result.nogroup = True return result @@ -2894,19 +2894,16 @@ def nonzero(monad): sqlquery = monad._subselect() sql_ast = [ 'EXISTS', sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] - translator = monad.translator - return BoolExprMonad(translator, sql_ast, nullable=False) + return BoolExprMonad(sql_ast, nullable=False) def negate(monad): sqlquery = monad._subselect() sql_ast = [ 
'NOT_EXISTS', sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] - translator = monad.translator - return BoolExprMonad(translator, sql_ast, nullable=False) + return BoolExprMonad(sql_ast, nullable=False) call_is_empty = negate def make_tableref(monad, sqlquery): parent = monad.parent attr = monad.attr - translator = monad.translator if isinstance(parent, ObjectMixin): parent_tableref = parent.tableref elif isinstance(parent, AttrSetMonad): parent_tableref = parent.make_tableref(sqlquery) else: assert False # pragma: no cover @@ -3041,7 +3038,7 @@ def __init__(monad, op, sqlop, left, right): assert type(result_type) is SetType if result_type.item_type not in numeric_types: throw(TypeError, _binop_errmsg % (type2str(left.type), type2str(right.type), op)) - Monad.__init__(monad, left.translator, result_type) + Monad.__init__(monad, result_type) monad.op = op monad.sqlop = sqlop monad.left = left @@ -3069,7 +3066,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): sql_ast = [ 'SELECT', [ 'AGGREGATES', aggr_ast ], sqlquery.from_ast, [ 'WHERE' ] + sqlquery.outer_conditions + sqlquery.conditions ] - result = ExprMonad.new(translator, result_type, sql_ast, nullable=func_name != 'SUM') + result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') result.nogroup = True else: if not translator.from_optimized: @@ -3079,7 +3076,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): translator.sqlquery.from_ast.extend(from_ast) translator.from_optimized = True sql_ast = aggr_ast - result = ExprMonad.new(translator, result_type, sql_ast, nullable=func_name != 'SUM') + result = ExprMonad.new(result_type, sql_ast, nullable=func_name != 'SUM') result.aggregated = True return result def getsql(monad, sqlquery=None): @@ -3104,13 +3101,12 @@ def getsql(monad, sqlquery=None): class QuerySetMonad(SetMixin, Monad): nogroup = True - def __init__(monad, translator, subtranslator): - monad.translator = translator - 
monad.subtranslator = subtranslator + def __init__(monad, subtranslator): item_type = subtranslator.expr_type - monad.item_type = item_type monad_type = SetType(item_type) - Monad.__init__(monad, translator, monad_type) + Monad.__init__(monad, monad_type) + monad.subtranslator = subtranslator + monad.item_type = item_type def requires_distinct(monad, joined=False): assert False def contains(monad, item, not_in=False): @@ -3197,7 +3193,7 @@ def contains(monad, item, not_in=False): having_ast += in_conditions else: where_ast += in_conditions sql_ast = [ 'NOT_EXISTS' if not_in else 'EXISTS' ] + subquery_ast[2:] - return BoolExprMonad(translator, sql_ast, nullable=False) + return BoolExprMonad(sql_ast, nullable=False) def nonzero(monad): subquery_ast = monad.subtranslator.construct_subquery_ast(distinct=False) expr_monads = monad.subtranslator.expr_monads @@ -3209,13 +3205,11 @@ def nonzero(monad): assert subquery_ast[3][0] == 'WHERE' subquery_ast[3].append(sql[0]) subquery_ast = [ 'EXISTS' ] + subquery_ast[2:] - translator = monad.translator - return BoolExprMonad(translator, subquery_ast, nullable=False) + return BoolExprMonad(subquery_ast, nullable=False) def negate(monad): sql = monad.nonzero().sql assert sql[0] == 'EXISTS' - translator = monad.translator - return BoolExprMonad(translator, [ 'NOT_EXISTS' ] + sql[1:], nullable=False) + return BoolExprMonad([ 'NOT_EXISTS' ] + sql[1:], nullable=False) def count(monad, distinct=None): distinct = distinct_from_monad(distinct) translator = monad.translator @@ -3257,11 +3251,10 @@ def count(monad, distinct=None): else: throw(NotImplementedError) # pragma: no cover if sql_ast is None: sql_ast = [ 'SELECT', select_ast, from_ast, where_ast ] - return ExprMonad.new(translator, int, sql_ast, nullable=False) + return ExprMonad.new(int, sql_ast, nullable=False) len = count def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 
'AVG')) - translator = monad.translator sub = monad.subtranslator if sub.aggregated: throw(TranslationError, 'Too complex aggregation in {EXPR}') subquery_ast = sub.construct_subquery_ast(distinct=False) @@ -3292,7 +3285,7 @@ def aggregate(monad, func_name, distinct=None, sep=None): result_type = unicode else: result_type = expr_type - return ExprMonad.new(translator, result_type, sql_ast, func_name != 'SUM') + return ExprMonad.new(result_type, sql_ast, func_name != 'SUM') def call_count(monad, distinct=None): return monad.count(distinct=distinct) def call_sum(monad, distinct=None): @@ -3309,10 +3302,7 @@ def call_group_concat(monad, sep=None, distinct=None): throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__) return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) def getsql(monad): - translator = monad.translator - sub = monad.subtranslator - subquery_ast = sub.construct_subquery_ast() - return subquery_ast + return monad.subtranslator.construct_subquery_ast() def find_or_create_having_ast(sections): groupby_offset = None From e22ed2d3056aa75840b96b084c95f9fd35ed3de7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 5 Aug 2018 17:42:58 +0300 Subject: [PATCH 367/547] Force joining for genexpr loop variables --- pony/orm/sqltranslation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index ac674ac7d..04607f920 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -366,6 +366,7 @@ def check_name_is_single(): if j == last_index: name_path = name else: name_path += '-' + attr.name tableref = translator.sqlquery.add_tableref(name_path, parent_tableref, attr) + tableref.make_join(pk_only=True) if j == last_index: translator.namespace[name] = ObjectIterMonad(tableref, tableref.entity) if can_affect_distinct is not None: From 69ddfccb285c3aa33940f7cc84cc0b742a9c29f1 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 
27 Jul 2018 15:18:38 +0300 Subject: [PATCH 368/547] Add more tests for `exists` --- pony/orm/tests/test_declarative_query_set_monad.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index b0a384a20..c0c989662 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -209,8 +209,16 @@ def test_avg_6(self): select(avg(s.scholarship, distinct=True) for s in Student)[:] self.assertTrue('AVG(DISTINCT' in db.last_sql) - def test_exists(self): - result = set(select(g for g in Group if exists(s for s in g.students if s.name == 'S1'))) + def test_exists_1(self): + result = set(select(g for g in Group if exists(s for s in g.students if s.age < 23))) + self.assertEqual(result, {Group[1]}) + + def test_exists_2(self): + result = set(select(g for g in Group if exists(s.age < 23 for s in g.students))) + self.assertEqual(result, {Group[1]}) + + def test_exists_3(self): + result = set(select(g for g in Group if (s.age < 23 for s in g.students))) self.assertEqual(result, {Group[1]}) def test_negate(self): From 9dc81a590d203d304236be98203df4c697c82743 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 7 Aug 2018 17:17:13 +0300 Subject: [PATCH 369/547] New feature: @db.on_connect decorator --- pony/orm/core.py | 53 +++++++++++++++++++++++++++++----- pony/orm/dbapiprovider.py | 9 ++++-- pony/orm/dbproviders/oracle.py | 2 +- pony/orm/tests/testutils.py | 2 +- 4 files changed, 54 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 20364a2ce..5efd72681 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -38,7 +38,7 @@ 'Warning', 'Error', 'InterfaceError', 'DatabaseError', 'DataError', 'OperationalError', 'IntegrityError', 'InternalError', 'ProgrammingError', 'NotSupportedError', - 'OrmError', 'ERDiagramError', 'DBSchemaError', 
'MappingError', + 'OrmError', 'ERDiagramError', 'DBSchemaError', 'MappingError', 'BindingError', 'TableDoesNotExist', 'TableIsNotEmpty', 'ConstraintError', 'CacheIndexError', 'ObjectNotFound', 'MultipleObjectsFoundError', 'TooManyObjectsFoundError', 'OperationWithDeletedObjectError', 'TransactionError', 'ConnectionClosedError', 'TransactionIntegrityError', 'IsolationError', @@ -134,6 +134,7 @@ class OrmError(Exception): pass class ERDiagramError(OrmError): pass class DBSchemaError(OrmError): pass class MappingError(OrmError): pass +class BindingError(OrmError): pass class TableDoesNotExist(OrmError): pass class TableIsNotEmpty(OrmError): pass @@ -636,6 +637,31 @@ def db_decorator(func, *args, **kwargs): if web: throw(web.Http404NotFound) raise +known_providers = ('sqlite', 'postgres', 'mysql', 'oracle') + +class OnConnectDecorator(object): + + @staticmethod + def check_provider(provider): + if provider: + if not isinstance(provider, basestring): + throw(TypeError, "'provider' option should be type of 'string', got %r" % type(provider).__name__) + if provider not in known_providers: + throw(BindingError, 'Unknown provider %s' % provider) + + def __init__(self, database, provider): + OnConnectDecorator.check_provider(provider) + self.provider = provider + self.database = database + + def __call__(self, func=None, provider=None): + if isinstance(func, types.FunctionType): + self.database._on_connect_funcs.append((func, provider or self.provider)) + if not provider and func is basestring: + provider = func + OnConnectDecorator.check_provider(provider) + return OnConnectDecorator(self.database, provider) + class Database(object): def __deepcopy__(self, memo): return self # Database cannot be cloned by deepcopy() @@ -658,15 +684,22 @@ def __init__(self, *args, **kwargs): self._global_stats_lock = RLock() self._dblocal = DbLocal() - self.provider = None + self.on_connect = OnConnectDecorator(self, None) + self._on_connect_funcs = [] + self.provider = self.provider_name = 
None if args or kwargs: self._bind(*args, **kwargs) + def call_on_connect(database, con): + for func, provider in database._on_connect_funcs: + if not provider or provider == database.provider_name: + func(database, con) + con.commit() @cut_traceback def bind(self, *args, **kwargs): self._bind(*args, **kwargs) def _bind(self, *args, **kwargs): # argument 'self' cannot be named 'database', because 'database' can be in kwargs if self.provider is not None: - throw(TypeError, 'Database object was already bound to %s provider' % self.provider.dialect) + throw(BindingError, 'Database object was already bound to %s provider' % self.provider.dialect) if args: provider, args = args[0], args[1:] elif 'provider' not in kwargs: throw(TypeError, 'Database provider is not specified') else: provider = kwargs.pop('provider') @@ -676,6 +709,7 @@ def _bind(self, *args, **kwargs): if not isinstance(provider, basestring): throw(TypeError) if provider == 'pygresql': throw(TypeError, 'Pony no longer supports PyGreSQL module. Please use psycopg2 instead.') + self.provider_name = provider provider_module = import_module('pony.orm.dbproviders.' 
+ provider) provider_cls = provider_module.provider_cls self.provider = provider_cls(*args, **kwargs) @@ -839,7 +873,7 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti def generate_mapping(database, filename=None, check_tables=True, create_tables=False): provider = database.provider if provider is None: throw(MappingError, 'Database object is not bound with a provider yet') - if database.schema: throw(MappingError, 'Mapping was already generated') + if database.schema: throw(BindingError, 'Mapping was already generated') if filename is not None: throw(NotImplementedError) schema = database.schema = provider.dbschema_cls(provider) entities = list(sorted(database.entities.values(), key=attrgetter('_id_'))) @@ -1631,12 +1665,17 @@ def connect(cache): assert cache.connection is None if cache.in_transaction: throw(ConnectionClosedError, 'Transaction cannot be continued because database connection failed') - provider = cache.database.provider - connection = provider.connect() - try: provider.set_transaction_mode(connection, cache) # can set cache.in_transaction + database = cache.database + provider = database.provider + connection, is_new_connection = provider.connect() + if is_new_connection: + database.call_on_connect(connection) + try: + provider.set_transaction_mode(connection, cache) # can set cache.in_transaction except: provider.drop(connection, cache) raise + cache.connection = connection return connection def reconnect(cache, exc): diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index ef7ad8bd8..10b49528b 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -111,7 +111,7 @@ def __init__(provider, *args, **kwargs): pool_mockup = kwargs.pop('pony_pool_mockup', None) if pool_mockup: provider.pool = pool_mockup else: provider.pool = provider.get_pool(*args, **kwargs) - connection = provider.connect() + connection, is_new_connection = provider.connect() provider.inspect_connection(connection) 
provider.release(connection) @@ -321,12 +321,15 @@ def connect(pool): pool.forked_connections.append((pool.con, pool.pid)) pool.con = pool.pid = None core = pony.orm.core + is_new_connection = False if pool.con is None: if core.local.debug: core.log_orm('GET NEW CONNECTION') + is_new_connection = True pool._connect() pool.pid = pid - elif core.local.debug: core.log_orm('GET CONNECTION FROM THE LOCAL POOL') - return pool.con + elif core.local.debug: + core.log_orm('GET CONNECTION FROM THE LOCAL POOL') + return pool.con, is_new_connection def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) def release(pool, con): diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 2928c58c1..50a6df949 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -581,7 +581,7 @@ def connect(pool): if core.local.debug: log_orm('GET CONNECTION') con = pool.cx_pool.acquire() con.outputtypehandler = output_type_handler - return con + return con, True def release(pool, con): pool.cx_pool.release(con) def drop(pool, con): diff --git a/pony/orm/tests/testutils.py b/pony/orm/tests/testutils.py index e87c65952..13f507fff 100644 --- a/pony/orm/tests/testutils.py +++ b/pony/orm/tests/testutils.py @@ -84,7 +84,7 @@ class TestPool(object): def __init__(pool, database): pool.database = database def connect(pool): - return TestConnection(pool.database) + return TestConnection(pool.database), True def release(pool, con): pass def drop(pool, con): From 148e0134ce4b01c7fcdb24720ed063b52b3622d7 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 8 Aug 2018 16:28:32 +0300 Subject: [PATCH 370/547] Closes #371: support of explicit casting of JSON values --- pony/orm/dbproviders/mysql.py | 2 ++ pony/orm/sqltranslation.py | 11 +++++++++++ pony/orm/tests/test_json.py | 11 +++++++++++ 3 files changed, 24 insertions(+) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 
14029a343..3ea432619 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -104,6 +104,8 @@ def JSON_VALUE(builder, expr, path, type): return 'NULLIF(', result, ", CAST('null' as JSON))" if type in (bool, int): return 'CAST(', result, ' AS SIGNED)' + if type is float: + return 'CAST(', result, ' AS DOUBLE)' return 'json_unquote(', result, ')' def JSON_NONZERO(builder, expr): return 'COALESCE(CAST(', builder(expr), ''' as CHAR), 'null') NOT IN ('null', 'false', '0', '""', '[]', '{}')''' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 04607f920..ded326a02 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1466,6 +1466,8 @@ def abs(monad): throw(TypeError) def cast_from_json(monad, type): assert False, monad def to_int(monad): return NumericExprMonad(int, [ 'TO_INT', monad.getsql()[0] ], nullable=monad.nullable) + def to_str(monad): + return StringExprMonad(unicode, [ 'TO_STR', monad.getsql()[0] ], nullable=monad.nullable) def to_real(monad): return NumericExprMonad(float, [ 'TO_REAL', monad.getsql()[0] ], nullable=monad.nullable) @@ -2207,6 +2209,10 @@ def get_path(monad): return monad, path def to_int(monad): return monad.cast_from_json(int) + def to_str(monad): + return monad.cast_from_json(unicode) + def to_real(monad): + return monad.cast_from_json(float) def cast_from_json(monad, type): translator = monad.translator if issubclass(type, Json): @@ -2457,6 +2463,11 @@ class FuncIntMonad(FuncMonad): def call(monad, x): return x.to_int() +class FuncStrMonad(FuncMonad): + func = str + def call(monad, x): + return x.to_str() + class FuncFloatMonad(FuncMonad): func = float def call(monad, x): diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 7e4dc42fd..23a50a142 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -609,6 +609,17 @@ def test_none_for_nonexistent_path(self): p = get(p for p in self.Product if p.info['some_attr'] is None) 
self.assertTrue(p) + @db_session + def test_str_cast(self): + p = get(coalesce(str(p.name), 'empty') for p in self.Product) + self.assertTrue('AS text' in self.db.last_sql) + + @db_session + def test_int_cast(self): + p = get(coalesce(int(p.info['os']['version']), 0) for p in self.Product) + self.assertTrue('as integer' in self.db.last_sql) + + def test_nonzero(self): Product = self.Product with db_session: From d00752047c91246db1d34f9fc82d4b4fd7bc4d3c Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 8 Aug 2018 17:56:47 +0300 Subject: [PATCH 371/547] on_connect fix --- pony/orm/core.py | 1 + pony/orm/dbapiprovider.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 5efd72681..1500d2164 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -712,6 +712,7 @@ def _bind(self, *args, **kwargs): self.provider_name = provider provider_module = import_module('pony.orm.dbproviders.' + provider) provider_cls = provider_module.provider_cls + kwargs['pony_call_on_connect'] = self.call_on_connect self.provider = provider_cls(*args, **kwargs) @property def last_sql(database): diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 10b49528b..2f185018c 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -109,9 +109,12 @@ class DBAPIProvider(object): def __init__(provider, *args, **kwargs): pool_mockup = kwargs.pop('pony_pool_mockup', None) + call_on_connect = kwargs.pop('pony_call_on_connect', None) if pool_mockup: provider.pool = pool_mockup else: provider.pool = provider.get_pool(*args, **kwargs) connection, is_new_connection = provider.connect() + if call_on_connect: + call_on_connect(connection) provider.inspect_connection(connection) provider.release(connection) From 79cbee90ec53d1e04e6608239134568675449a2f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 8 Aug 2018 18:39:57 +0300 Subject: [PATCH 372/547] Update changelog and pony version: 0.7.6-dev -> 0.7.6rc1 
--- CHANGELOG.md | 23 +++++++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index af46c62b9..8c3c325b9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,26 @@ +# Pony ORM Release 0.7.6rc1 (2018-08-08) + +## New features + +* f-strings support in queries: select(f'{s.name} - {s.age}' for s in Student) +* #344: It is now possible to specify offset without limit: `query.limit(offset=10)` +* #371: Support of explicit casting of JSON expressions to `str`, `int` or `float` +* `@db.on_connect` decorator added + +## Bugfixes + +* Fix bulk delete bug introduced in 0.7.4 +* #370 Fix memory leak introduced in 0.7.4 +* Now exists() in query does not throw away condition in generator expression: `exists(s.gpa > 3 for s in Student)` +* #373: 0.7.4/0.7.5 breaks queries using the `in` operator to test membership of another query result +* #374: `auto=True` can be used with all PrimaryKey types, not only int +* #369: Make QueryResult looks like a list object again: add concatenation with lists, `.shuffle()` and `.to_list()` methods +* #355: Fix binary primary keys `PrimaryKey(buffer)` in Python2 +* Interactive mode support for PyCharm console +* Fix wrong table aliases in complex queries +* Fix query optimization code for complex queries + + # Pony ORM Release 0.7.5 (2018-07-24) ## Bugfixes diff --git a/pony/__init__.py b/pony/__init__.py index 388a44fc2..32e85918a 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.6-dev' +__version__ = '0.7.6rc1' uid = str(random.randint(1, 1000000)) From 585ab81d5031f11ed6cc28fc036083355e296abf Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 9 Aug 2018 14:47:56 +0300 Subject: [PATCH 373/547] Fixed a bug with hybrid properties that use external functions --- pony/orm/core.py | 5 +++-- .../test_hybrid_methods_and_properties.py | 18 ++++++++++++++++++ 2 
files changed, 21 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 1500d2164..766d2c855 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5433,9 +5433,10 @@ def _get_translator(query, query_key, vars): if translator is not None: if translator.func_extractors_map: for func, func_extractors in iteritems(translator.func_extractors_map): - func_filter_num = translator.filter_num, 'func', id(func) + func_id = id(func.func_code if PY2 else func.__code__) + func_filter_num = translator.filter_num, 'func', func_id func_vars, func_vartypes = extract_vars( - func_filter_num, func_extractors, func.__globals__, {}, func.__closure__) # todo closures + func_id, func_filter_num, func_extractors, func.__globals__, {}, func.__closure__) # todo closures database.provider.normalize_vars(func_vars, func_vartypes) new_vars.update(func_vars) all_func_vartypes.update(func_vartypes) diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py index 6190e3fb0..64e2f1cda 100644 --- a/pony/orm/tests/test_hybrid_methods_and_properties.py +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -15,6 +15,10 @@ class Person(db.Entity): def full_name(self): return self.first_name + ' ' + self.last_name + @property + def full_name_2(self): + return concat(self.first_name, ' ', self.last_name) # tests using of function `concat` from external scope + @property def has_car(self): return not self.cars.is_empty() @@ -33,6 +37,10 @@ def cars_price(self): def incorrect_full_name(self): return self.first_name + ' ' + p.last_name # p is FakePerson instance here + @classmethod + def find_by_full_name(cls, full_name): + return cls.select(lambda p: p.full_name_2 == full_name) + class FakePerson(object): pass @@ -139,6 +147,16 @@ def test14(self): persons = select(p.incorrect_full_name for p in Person if p.has_car)[:] self.assertEqual(set(persons), {'Alexander ***', 'Alexei ***', 'Alexander ***'}) + 
@db_session + def test15(self): + # Test repeated use of the same generator with hybrid method/property that uses function from external scope + result = Person.find_by_full_name('Alexander Kozlovsky') + self.assertEqual(set(obj.last_name for obj in result), {'Kozlovsky'}) + result = Person.find_by_full_name('Alexander Kozlovsky') + self.assertEqual(set(obj.last_name for obj in result), {'Kozlovsky'}) + result = Person.find_by_full_name('Alexander Tischenko') + self.assertEqual(set(obj.last_name for obj in result), {'Tischenko'}) + if __name__ == '__main__': unittest.main() From d88cb15072e740ed6ec316390d5a4502b3e3c8ae Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 9 Aug 2018 14:13:57 +0300 Subject: [PATCH 374/547] Improved error message --- pony/orm/core.py | 2 +- pony/orm/tests/test_declarative_exceptions.py | 24 +++++++++---------- pony/orm/tests/test_declarative_func_monad.py | 2 +- .../tests/test_declarative_sqltranslator2.py | 2 +- pony/orm/tests/test_query.py | 4 ++-- pony/orm/tests/test_raw_sql.py | 4 ++-- .../tests/test_select_from_select_queries.py | 2 +- 7 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 766d2c855..256b26f1e 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -202,7 +202,7 @@ def __init__(exc, msg, original_exc): class ExprEvalError(TranslationError): def __init__(exc, src, cause): assert isinstance(cause, Exception) - msg = '%s raises %s: %s' % (src, type(cause).__name__, str(cause)) + msg = '`%s` raises %s: %s' % (src, type(cause).__name__, str(cause)) TranslationError.__init__(exc, msg) exc.cause = cause diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index cd04aa263..1900b9136 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -72,26 +72,26 @@ def test4(self): def test5(self): select(s for s in Student if s.name.upper(**{'a':'b', 'c':'d'}))
- @raises_exception(ExprEvalError, "1 in 2 raises TypeError: argument of type 'int' is not iterable" if not PYPY else - "1 in 2 raises TypeError: 'int' object is not iterable") + @raises_exception(ExprEvalError, "`1 in 2` raises TypeError: argument of type 'int' is not iterable" if not PYPY else + "`1 in 2` raises TypeError: 'int' object is not iterable") def test6(self): select(s for s in Student if 1 in 2) @raises_exception(NotImplementedError, 'Group[s.group.number]') def test7(self): select(s for s in Student if Group[s.group.number].dept.number == 44) - @raises_exception(ExprEvalError, "Group[123, 456].dept.number == 44 raises TypeError: Invalid count of attrs in Group primary key (2 instead of 1)") + @raises_exception(ExprEvalError, "`Group[123, 456].dept.number == 44` raises TypeError: Invalid count of attrs in Group primary key (2 instead of 1)") def test8(self): select(s for s in Student if Group[123, 456].dept.number == 44) - @raises_exception(ExprEvalError, "Course[123] raises TypeError: Invalid count of attrs in Course primary key (1 instead of 2)") + @raises_exception(ExprEvalError, "`Course[123]` raises TypeError: Invalid count of attrs in Course primary key (1 instead of 2)") def test9(self): select(s for s in Student if Course[123] in s.courses) @raises_exception(TypeError, "Incomparable types '%s' and 'float' in expression: s.name < s.gpa" % unicode.__name__) def test10(self): select(s for s in Student if s.name < s.gpa) - @raises_exception(ExprEvalError, "Group(101) raises TypeError: Group constructor accept only keyword arguments. Got: 1 positional argument") + @raises_exception(ExprEvalError, "`Group(101)` raises TypeError: Group constructor accept only keyword arguments. Got: 1 positional argument") def test11(self): select(s for s in Student if s.group == Group(101)) - @raises_exception(ExprEvalError, "Group[date(2011, 1, 2)] raises TypeError: Value type for attribute Group.number must be int. 
Got: %r" % date) + @raises_exception(ExprEvalError, "`Group[date(2011, 1, 2)]` raises TypeError: Value type for attribute Group.number must be int. Got: %r" % date) def test12(self): select(s for s in Student if s.group == Group[date(2011, 1, 2)]) @raises_exception(TypeError, "Unsupported operand types 'int' and '%s' for operation '+' in expression: s.group.number + s.name" % unicode.__name__) @@ -142,7 +142,7 @@ def test26(self): @raises_exception(AttributeError, "Entity Group does not have attribute foo: s.group.foo") def test27(self): select(s.name for s in Student if s.group.foo.bar == 10) - @raises_exception(ExprEvalError, "g.dept.foo.bar raises AttributeError: 'Department' object has no attribute 'foo'") + @raises_exception(ExprEvalError, "`g.dept.foo.bar` raises AttributeError: 'Department' object has no attribute 'foo'") def test28(self): g = Group[101] select(s for s in Student if s.name == g.dept.foo.bar) @@ -153,8 +153,8 @@ def test29(self): @raises_exception(NotImplementedError, "date(s.id, 1, 1)") def test30(self): select(s for s in Student if s.dob < date(s.id, 1, 1)) - @raises_exception(ExprEvalError, "max() raises TypeError: max expected 1 arguments, got 0" if not PYPY else - "max() raises TypeError: max() expects at least one argument") + @raises_exception(ExprEvalError, "`max()` raises TypeError: max expected 1 arguments, got 0" if not PYPY else + "`max()` raises TypeError: max() expects at least one argument") def test31(self): select(s for s in Student if s.id < max()) @raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses") @@ -182,9 +182,9 @@ def test38(self): def test39(self): select(s for s in Student if s.name.strip(1, 2, 3)) @raises_exception(ExprEvalError, - "len(1, 2) == 3 raises TypeError: len() takes exactly 1 argument (2 given)" if PYPY2 else - "len(1, 2) == 3 raises TypeError: len() takes 1 positional argument but 2 were given" if PYPY else - "len(1, 2) == 3 raises TypeError: len() takes 
exactly one argument (2 given)") + "`len(1, 2) == 3` raises TypeError: len() takes exactly 1 argument (2 given)" if PYPY2 else + "`len(1, 2) == 3` raises TypeError: len() takes 1 positional argument but 2 were given" if PYPY else + "`len(1, 2) == 3` raises TypeError: len() takes exactly one argument (2 given)") def test40(self): select(s for s in Student if len(1, 2) == 3) @raises_exception(TypeError, "Function sum() expects query or items of numeric type, got 'Student' in sum(s for s in Student if s.group == g)") diff --git a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index c7e87f30f..656c3de48 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -113,7 +113,7 @@ def test_datetime_func4(self): def test_datetime_now1(self): result = set(select(s for s in Student if s.dob < date.today())) self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) - @raises_exception(ExprEvalError, "1 < datetime.now() raises TypeError: " + ( + @raises_exception(ExprEvalError, "`1 < datetime.now()` raises TypeError: " + ( "can't compare 'datetime' to 'int'" if PYPY2 else "unorderable types: int < datetime" if PYPY else "can't compare datetime.datetime to int" if PY2 else diff --git a/pony/orm/tests/test_declarative_sqltranslator2.py b/pony/orm/tests/test_declarative_sqltranslator2.py index ba684ee7e..a45889e60 100644 --- a/pony/orm/tests/test_declarative_sqltranslator2.py +++ b/pony/orm/tests/test_declarative_sqltranslator2.py @@ -177,7 +177,7 @@ def test_exception2(self): get(s for s in Student) def test_exists(self): result = exists(s for s in Student) - @raises_exception(ExprEvalError, "db.FooBar raises AttributeError: 'Database' object has no attribute 'FooBar'") + @raises_exception(ExprEvalError, "`db.FooBar` raises AttributeError: 'Database' object has no attribute 'FooBar'") def test_entity_not_found(self): select(s for s in db.Student for 
g in db.FooBar) def test_keyargs1(self): diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 070e15ea4..911638ac6 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -47,8 +47,8 @@ def test2(self): def test3(self): g = Group[1] select(s for s in g.students) - @raises_exception(ExprEvalError, "a raises NameError: global name 'a' is not defined" if PYPY2 else - "a raises NameError: name 'a' is not defined") + @raises_exception(ExprEvalError, "`a` raises NameError: global name 'a' is not defined" if PYPY2 else + "`a` raises NameError: name 'a' is not defined") def test4(self): select(a for s in Student) @raises_exception(TypeError, "Incomparable types '%s' and 'list' in expression: s.name == x" % unicode.__name__) diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index f56247b01..99ab1aa86 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -160,8 +160,8 @@ def test_19(self): @db_session @raises_exception(ExprEvalError, - "raw_sql('p.dob < $x') raises NameError: global name 'x' is not defined" if PYPY2 else - "raw_sql('p.dob < $x') raises NameError: name 'x' is not defined") + "`raw_sql('p.dob < $x')` raises NameError: global name 'x' is not defined" if PYPY2 else + "`raw_sql('p.dob < $x')` raises NameError: name 'x' is not defined") def test_20(self): # testing for situation where parameter variable is missing select(p for p in Person if raw_sql('p.dob < $x'))[:] diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index 1b3e88997..587f4205d 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -88,7 +88,7 @@ def test_6(self): # selecting hybrid property in the first query self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session - @raises_exception(ExprEvalError, "s.scholarship > 0 raises NameError: name 's' is not 
defined") + @raises_exception(ExprEvalError, "`s.scholarship > 0` raises NameError: name 's' is not defined") def test_7(self): # test access to original query var name from the new query q = select(s.first_name for s in Student if s.scholarship < 500) q2 = select(x for x in q if s.scholarship > 0) From 260795ead1cdff8f1496aa600c3c20924dc27607 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 10 Aug 2018 12:26:43 +0300 Subject: [PATCH 375/547] Update changelog and pony version: 0.7.6rc1 -> 0.7.6 --- CHANGELOG.md | 13 ++++++++++--- pony/__init__.py | 2 +- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8c3c325b9..8ffffbd27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,15 @@ +# Pony ORM Release 0.7.6 (2018-08-10) + +## Bugfixes + +* Fixed a bug with hybrid properties that use external functions + + # Pony ORM Release 0.7.6rc1 (2018-08-08) ## New features -* f-strings support in queries: select(f'{s.name} - {s.age}' for s in Student) +* f-strings support in queries: `select(f'{s.name} - {s.age}' for s in Student)` * #344: It is now possible to specify offset without limit: `query.limit(offset=10)` * #371: Support of explicit casting of JSON expressions to `str`, `int` or `float` * `@db.on_connect` decorator added @@ -11,9 +18,9 @@ * Fix bulk delete bug introduced in 0.7.4 * #370 Fix memory leak introduced in 0.7.4 -* Now exists() in query does not throw away condition in generator expression: `exists(s.gpa > 3 for s in Student)` +* Now `exists()` in query does not throw away condition in generator expression: `exists(s.gpa > 3 for s in Student)` * #373: 0.7.4/0.7.5 breaks queries using the `in` operator to test membership of another query result -* #374: `auto=True` can be used with all PrimaryKey types, not only int +* #374: `auto=True` can be used with all PrimaryKey types, not only `int` * #369: Make QueryResult looks like a list object again: add concatenation with lists, `.shuffle()` and 
`.to_list()` methods * #355: Fix binary primary keys `PrimaryKey(buffer)` in Python2 * Interactive mode support for PyCharm console diff --git a/pony/__init__.py b/pony/__init__.py index 32e85918a..9cbb066e6 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.6rc1' +__version__ = '0.7.6' uid = str(random.randint(1, 1000000)) From 95b3a24db8289d30631078b6a9aa01d6512983fc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 22 Sep 2018 14:44:52 +0300 Subject: [PATCH 376/547] Update Pony version: 0.7.6 -> 0.7.7-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 9cbb066e6..de1834556 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.6' +__version__ = '0.7.7-dev' uid = str(random.randint(1, 1000000)) From 5cc8fa8fc1fd701b9587eb14500a18cf896f2e0a Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 19 Sep 2018 18:26:17 +0300 Subject: [PATCH 377/547] typo: NotImplemented used instead of NotImplementedError --- pony/orm/sqltranslation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index ded326a02..946f224e7 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -255,11 +255,11 @@ def init(translator, tree, parent_translator, code_key=None, filter_num=None, ex elif isinstance(assign, ast.AssName): ass_names = (assign,) else: - throw(NotImplemented, ast2src(assign)) + throw(NotImplementedError, ast2src(assign)) for ass_name in ass_names: if not isinstance(ass_name, ast.AssName): - throw(NotImplemented, ast2src(ass_name)) + throw(NotImplementedError, ast2src(ass_name)) if ass_name.flags != 'OP_ASSIGN': throw(TypeError, ast2src(ass_name)) From 810b92f6d0d58854ee74681518634421604faf2d Mon Sep 17 
00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 18 Aug 2018 02:12:34 +0300 Subject: [PATCH 378/547] Negative JSON array indexes in SQLite --- pony/orm/dbproviders/sqlite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index ff3196b04..dc356fc6f 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -433,7 +433,7 @@ def py_json_unwrap(value): path_cache = {} -json_path_re = re.compile(r'\[(\d+)\]|\.(?:(\w+)|"([^"]*)")', re.UNICODE) +json_path_re = re.compile(r'\[(-?\d+)\]|\.(?:(\w+)|"([^"]*)")', re.UNICODE) def _parse_path(path): if path in path_cache: From 00440cfc5353f865b1f0f43f5f244a9c75103dd9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 19 Sep 2018 17:17:17 +0300 Subject: [PATCH 379/547] Refactoring: normalize EntityIter to Entity --- pony/orm/core.py | 4 +--- pony/orm/ormtypes.py | 3 ++- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 256b26f1e..e3c9ed7a7 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5359,9 +5359,7 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False prev_query = origin._query_result._query else: prev_query = None - if isinstance(origin, EntityIter): - origin = origin.entity - elif not isinstance(origin, EntityMeta): + if not isinstance(origin, EntityMeta): if node.src == '.0': throw(TypeError, 'Query can only iterate over entity or another query (not a list of objects)') throw(TypeError, 'Cannot iterate over non-entity object %s' % node.src) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index c29942350..36634c48c 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -158,7 +158,8 @@ def normalize(value): return SetType(value), value if t.__name__ == 'EntityIter': - return SetType(value.entity), value + entity = value.entity + return SetType(entity), entity if PY2 and isinstance(value, str): 
try: From 3e8a952dbc494556f9cf199368805b7d77736e18 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 17 Aug 2018 14:02:01 +0300 Subject: [PATCH 380/547] Improved error message --- pony/orm/tests/test_declarative_exceptions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 1900b9136..718343b40 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -54,11 +54,11 @@ def tearDown(self): def test1(self): x = 10 select(s for s in Student for x in s.name) - @raises_exception(TranslationError, "Inside declarative query, iterator must be entity. Got: for i in x") + @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. Got: for i in x") def test2(self): x = [1, 2, 3] select(s for s in Student for i in x) - @raises_exception(TranslationError, "Inside declarative query, iterator must be entity. Got: for s2 in g.students") + @raises_exception(TranslationError, "Inside declarative query, iterator must be entity or query. 
Got: for s2 in g.students") def test3(self): g = Group[101] select(s for s in Student for s2 in g.students) From da35e18daa98b3e564f40122a9fc4f8aacd1ebfd Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 22 Sep 2018 14:06:31 +0300 Subject: [PATCH 381/547] Support of limit in "select from select" type of queries --- pony/orm/core.py | 20 +++-- pony/orm/ormtypes.py | 14 +++- pony/orm/sqltranslation.py | 74 +++++++++++++++---- .../tests/test_select_from_select_queries.py | 32 ++++++++ 4 files changed, 115 insertions(+), 25 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e3c9ed7a7..56c774594 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5308,14 +5308,14 @@ def extract_vars(code_key, filter_num, extractors, globals, locals, cells=None): value = make_query((value,), frame_depth=None) if isinstance(value, QueryResultIterator): - query_result = value._query_result - if query_result._items: - value = tuple(query_result._items[value._position:]) - else: - value = value._query_result._query + qr = value._query_result + value = qr if not qr._items else tuple(qr._items[value._position:]) + + if isinstance(value, QueryResult) and value._items: + value = tuple(value._items) - if isinstance(value, Query): - query = value + if isinstance(value, (Query, QueryResult)): + query = value._query if isinstance(value, QueryResult) else value vars.update(query._vars) vartypes.update(query._translator.vartypes) @@ -5981,7 +5981,7 @@ def __init__(self, query, limit, offset, lazy): self._col_names = translator.col_names def _get_type_(self): if self._items is None: - return QueryType(self._query) + return QueryType(self._query, self._limit, self._offset) item_type = self._query._translator.expr_type return tuple(item_type for item in self._items) def _normalize_var(self, query_type): @@ -6000,6 +6000,10 @@ def __setstate__(self, state): self._query = None self._items, self._limit, self._offset, self._expr_type, self._col_names = state def 
__repr__(self): + if self._items is not None: + return self.__str__() + return '' % hex(id(self)) + def __str__(self): return repr(self._get_items()) def __iter__(self): return QueryResultIterator(self) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 36634c48c..5ab634a0e 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -124,13 +124,21 @@ def __ne__(self, other): return not self.__eq__(other) class QueryType(object): - def __init__(self, query): + def __init__(self, query, limit=None, offset=None): self.query_key = query._key self.translator = query._translator + self.limit = limit + self.offset = offset def __hash__(self): - return hash(self.query_key) + result = hash(self.query_key) + if self.limit is not None: + result ^= hash(self.limit + 3) + if self.offset is not None: + result ^= hash(self.offset) + return result def __eq__(self, other): - return type(other) is QueryType and self.query_key == other.query_key + return type(other) is QueryType and self.query_key == other.query_key \ + and self.limit == other.limit and self.offset == other.offset def __ne__(self, other): return not self.__eq__(other) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 946f224e7..cc8d0250f 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, with_metaclass +from pony.py23compat import PY2, items_list, izip, xrange, basestring, unicode, buffer, with_metaclass, int_types import types, sys, re, itertools, inspect from decimal import Decimal @@ -121,6 +121,8 @@ def dispatch_external(translator, node): if translator.database is not prev_translator.database: throw(TranslationError, 'Mixing queries from different databases') monad = QuerySetMonad(prev_translator) + if t.limit is not None or t.offset is not None: + monad = monad.call_limit(t.limit, 
t.offset) elif tt is FuncType: func = t.func func_monad_class = translator.registered_functions.get(func, ErrorSpecialFuncMonad) @@ -243,6 +245,7 @@ def init(translator, tree, parent_translator, code_key=None, filter_num=None, ex translator.conditions = translator.sqlquery.conditions translator.having_conditions = [] translator.order = [] + translator.limit = translator.offset = None translator.inside_order_by = False translator.aggregated = False if not optimize else True translator.hint_join = False @@ -310,14 +313,17 @@ def check_name_is_single(): translator.namespace[name] = node.monad = ObjectIterMonad(tableref, entity) elif isinstance(iterable, QueryType): prev_translator = deepcopy(iterable.translator) + prev_limit = iterable.limit + prev_offset = iterable.offset database = prev_translator.database try: - translator.process_query_qual(prev_translator, names, try_extend_prev_query=not i) + translator.process_query_qual(prev_translator, prev_limit, prev_offset, + names, try_extend_prev_query=not i) except UseAnotherTranslator as e: assert local.translators and local.translators[-1] is translator translator = e.translator local.translators[-1] = translator - else: throw(TranslationError, 'Inside declarative query, iterator must be entity. ' + else: throw(TranslationError, 'Inside declarative query, iterator must be entity or query. 
' 'Got: for %s in %s' % (name, ast2src(qual.iter))) else: @@ -328,7 +334,7 @@ def check_name_is_single(): subtranslator = monad.subtranslator database = subtranslator.database try: - translator.process_query_qual(subtranslator, names) + translator.process_query_qual(subtranslator, monad.limit, monad.offset, names) except UseAnotherTranslator: assert False else: @@ -493,7 +499,7 @@ def can_be_optimized(translator): if tableref.joined and not aggr_path.startswith(tableref.name_path): return False return aggr_path - def process_query_qual(translator, prev_translator, names, try_extend_prev_query=False): + def process_query_qual(translator, prev_translator, prev_limit, prev_offset, names, try_extend_prev_query=False): sqlquery = translator.sqlquery tablerefs = sqlquery.tablerefs expr_types = prev_translator.expr_type @@ -533,15 +539,18 @@ def process_query_qual(translator, prev_translator, names, try_extend_prev_query prev_translator.namespace_stack = [ {name: expr for name, expr in izip(names, prev_translator.expr_monads)} ] + prev_translator.limit, prev_translator.offset = combine_limit_and_offset( + prev_translator.limit, prev_translator.offset, prev_limit, prev_offset) raise UseAnotherTranslator(prev_translator) + if len(names) == 1 and isinstance(prev_translator.expr_type, EntityMeta) \ and not prev_translator.aggregated and not prev_translator.distinct: name = names[0] entity = prev_translator.expr_type [expr_monad] = prev_translator.expr_monads entity_alias = expr_monad.tableref.alias - subquery_ast = prev_translator.construct_subquery_ast(star=entity_alias) + subquery_ast = prev_translator.construct_subquery_ast(prev_limit, prev_offset, star=entity_alias) tableref = StarTableRef(sqlquery, name, entity, subquery_ast) tablerefs[name] = tableref tableref.make_join() @@ -562,7 +571,7 @@ def process_query_qual(translator, prev_translator, names, try_extend_prev_query aliases.append(name) aliases_dict[base_expr_monad] = name - subquery_ast = 
prev_translator.construct_subquery_ast(aliases=aliases) + subquery_ast = prev_translator.construct_subquery_ast(prev_limit, prev_offset, aliases=aliases) tableref = ExprTableRef(sqlquery, 't', subquery_ast, names, aliases) for name in names: tablerefs[name] = tableref @@ -580,8 +589,10 @@ def process_query_qual(translator, prev_translator, names, try_extend_prev_query expr_monad = ExprMonad.new(t, expr_ast, base_expr_monad.nullable) assert name not in translator.namespace translator.namespace[name] = expr_monad - def construct_subquery_ast(translator, aliases=None, star=None, distinct=None, is_not_null_checks=False): - subquery_ast, attr_offsets = translator.construct_sql_ast(distinct=distinct, is_not_null_checks=is_not_null_checks) + def construct_subquery_ast(translator, limit=None, offset=None, aliases=None, star=None, + distinct=None, is_not_null_checks=False): + subquery_ast, attr_offsets = translator.construct_sql_ast( + limit, offset, distinct, is_not_null_checks=is_not_null_checks) assert len(subquery_ast) >= 3 and subquery_ast[0] == 'SELECT' select_ast = subquery_ast[1][:] @@ -720,6 +731,7 @@ def ast_transformer(ast): if translator.order and not aggr_func_name: sql_ast.append([ 'ORDER_BY' ] + translator.order) + limit, offset = combine_limit_and_offset(translator.limit, translator.offset, limit, offset) if limit is not None or offset is not None: assert not aggr_func_name provider = translator.database.provider @@ -1106,6 +1118,27 @@ def postJoinedStr(translator, node): return StringExprMonad(unicode, sql, nullable=nullable) def postFormattedValue(translator, node): throw(NotImplementedError, 'You cannot set width and precision markers in query') + +def combine_limit_and_offset(limit, offset, limit2, offset2): + assert limit is None or limit >= 0 + assert limit2 is None or limit2 >= 0 + + if offset2 is not None: + if limit is not None: + limit = max(0, limit - offset2) + offset = (offset or 0) + offset2 + + if limit2 is not None: + if limit is not None: + 
limit = min(limit, limit2) + else: + limit = limit2 + + if limit == 0: + offset = None + + return limit, offset + def coerce_monads(m1, m2, for_comparison=False): result_type = coerce_types(m1.type, m2.type) if result_type in numeric_types and bool in (m1.type, m2.type) and ( @@ -3119,8 +3152,21 @@ def __init__(monad, subtranslator): Monad.__init__(monad, monad_type) monad.subtranslator = subtranslator monad.item_type = item_type + monad.limit = monad.offset = None def requires_distinct(monad, joined=False): assert False + def call_limit(monad, limit=None, offset=None): + if limit is not None and not isinstance(limit, int_types): + if not isinstance(limit, (NoneMonad, NumericConstMonad)): + throw(TypeError, '`limit` parameter should be of int type') + limit = limit.value + if offset is not None and not isinstance(offset, int_types): + if not isinstance(offset, (NoneMonad, NumericConstMonad)): + throw(TypeError, '`offset` parameter should be of int type') + offset = offset.value + monad.limit = limit + monad.offset = offset + return monad def contains(monad, item, not_in=False): translator = monad.translator check_comparable(item, monad, 'in') @@ -3131,7 +3177,7 @@ def contains(monad, item, not_in=False): sub = monad.subtranslator if translator.hint_join and len(sub.sqlquery.from_ast[1]) == 3: - subquery_ast = sub.construct_subquery_ast(distinct=False) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False) select_ast, from_ast, where_ast = subquery_ast[1:4] sqlquery = translator.sqlquery if not not_in: @@ -3167,10 +3213,10 @@ def contains(monad, item, not_in=False): else: sql_ast = [ 'EQ', [ 'VALUE', 1 ], [ 'VALUE', 1 ] ] else: if len(item_columns) == 1: - subquery_ast = sub.construct_subquery_ast(distinct=False, is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', item_columns[0], subquery_ast ] elif 
translator.row_value_syntax: - subquery_ast = sub.construct_subquery_ast(distinct=False, is_not_null_checks=not_in) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False, is_not_null_checks=not_in) sql_ast = [ 'NOT_IN' if not_in else 'IN', [ 'ROW' ] + item_columns, subquery_ast ] else: ambiguous_names = set() @@ -3178,7 +3224,7 @@ def contains(monad, item, not_in=False): for name in translator.sqlquery.tablerefs: if name in sub.sqlquery.tablerefs: ambiguous_names.add(name) - subquery_ast = sub.construct_subquery_ast(distinct=False) + subquery_ast = sub.construct_subquery_ast(monad.limit, monad.offset, distinct=False) if ambiguous_names: select_ast = subquery_ast[1] expr_aliases = [] @@ -3314,7 +3360,7 @@ def call_group_concat(monad, sep=None, distinct=None): throw(TypeError, '`sep` option of `group_concat` should be type of str. Got: %s' % type(sep).__name__) return monad.aggregate('GROUP_CONCAT', distinct, sep=sep) def getsql(monad): - return monad.subtranslator.construct_subquery_ast() + return monad.subtranslator.construct_subquery_ast(monad.limit, monad.offset) def find_or_create_having_ast(sections): groupby_offset = None diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index 587f4205d..9210687e4 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -356,6 +356,38 @@ def f2(q): q2 = f2(q) self.assertEqual(set(q2), {'Lee'}) + @db_session + def test_42(self): + q = select(s for s in Student if s.scholarship > 0) + q2 = select(g for g in Group if g.major == 'Computer Science')[:] + q3 = select(s.first_name for s in q if s.group in q2) + self.assertEqual(set(q3), {'Alex', 'Mary'}) + + @db_session + def test_43(self): + q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) + q2 = select(s.first_name for s in Student if s in q) + self.assertEqual(set(q2), {'John', 'Bruce'}) + + 
@db_session + def test_44(self): + q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) + q2 = select(s.first_name for s in q) + self.assertEqual(set(q2), {'Bruce', 'John', 'Mary'}) + + @db_session + def test_45(self): + q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) + q2 = select(s for s in q if s.age > 18).limit(2, offset=1) + q3 = select(s.last_name for s in q2).limit(2, offset=1) + self.assertEqual(set(q3), {'Brown'}) + + @db_session + def test_46(self): + q = select((c, count(c.students)) for c in Course).order_by(-2).limit(2) + q2 = select((c.name, c.credits, m) for c, m in q).limit(1, offset=1) + self.assertEqual(set(q2), {('3D Modeling', 15, 2)}) + if __name__ == '__main__': unittest.main() From fe56efa94dea1e57f9fd8194e7caf5748752e0c4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 19 Sep 2018 17:06:53 +0300 Subject: [PATCH 382/547] Fix #380: db_session should work with async functions --- pony/orm/core.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 56c774594..b7c5bfa62 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2,7 +2,7 @@ from pony.py23compat import PY2, izip, imap, iteritems, itervalues, items_list, values_list, xrange, cmp, \ basestring, unicode, buffer, int_types, builtins, with_metaclass -import json, re, sys, types, datetime, logging, itertools, warnings +import json, re, sys, types, datetime, logging, itertools, warnings, inspect from operator import attrgetter, itemgetter from itertools import chain, starmap, repeat from time import time @@ -13,6 +13,7 @@ from collections import defaultdict from hashlib import md5 from inspect import isgeneratorfunction +from functools import wraps from pony.thirdparty.compiler import ast, parse @@ -406,9 +407,9 @@ def __call__(db_session, *args, **kwargs): if kwargs: throw(TypeError, 'Pass only keyword arguments to db_session or use db_session 
as decorator') func = args[0] - if not isgeneratorfunction(func): - return db_session._wrap_function(func) - return db_session._wrap_generator_function(func) + if isgeneratorfunction(func) or hasattr(inspect, 'iscoroutinefunction') and inspect.iscoroutinefunction(func): + return db_session._wrap_coroutine_or_generator_function(func) + return db_session._wrap_function(func) def __enter__(db_session): if db_session.retry is not 0: throw(TypeError, "@db_session can accept 'retry' parameter only when used as decorator and not as context manager") @@ -483,7 +484,7 @@ def new_func(func, *args, **kwargs): if db_session.sql_debug is not None: local.pop_debug_state() return decorator(new_func, func) - def _wrap_generator_function(db_session, gen_func): + def _wrap_coroutine_or_generator_function(db_session, gen_func): for option in ('ddl', 'retry', 'serializable'): if getattr(db_session, option, None): throw(TypeError, "db_session with `%s` option cannot be applied to generator function" % option) @@ -501,7 +502,8 @@ def interact(iterator, input=None, exc_info=None): if throw_ is None: reraise(*exc_info) return throw_(*exc_info) - def new_gen_func(gen_func, *args, **kwargs): + @wraps(gen_func) + def new_gen_func(*args, **kwargs): db2cache_copy = {} def wrapped_interact(iterator, input=None, exc_info=None): @@ -538,7 +540,7 @@ def wrapped_interact(iterator, input=None, exc_info=None): local.db_session = None gen = gen_func(*args, **kwargs) - iterator = iter(gen) + iterator = gen.__await__() if hasattr(gen, '__await__') else iter(gen) output = wrapped_interact(iterator) try: while True: @@ -550,7 +552,10 @@ def wrapped_interact(iterator, input=None, exc_info=None): output = wrapped_interact(iterator, input) except StopIteration: return - return decorator(new_gen_func, gen_func) + + if hasattr(types, 'coroutine'): + new_gen_func = types.coroutine(new_gen_func) + return new_gen_func db_session = DBSessionContextManager() From 76093b92bba75b3ccb46f6c1b4f519894481dfc0 Mon Sep 17 
00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 30 Sep 2018 06:22:30 +0300 Subject: [PATCH 383/547] Remove unused code --- pony/orm/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b7c5bfa62..604d51cd5 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3976,7 +3976,6 @@ def _construct_select_clause_(entity, alias=None, distinct=False, def _construct_discriminator_criteria_(entity, alias=None): discr_attr = entity._discriminator_attr_ if discr_attr is None: return None - code2cls = discr_attr.code2cls discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in entity._subclasses_ ] discr_values.append([ 'VALUE', entity._discriminator_]) return [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] From f5e24f2743b4e1a298291a07f74144861b809cb4 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 30 Sep 2018 07:01:18 +0300 Subject: [PATCH 384/547] Add isinstance() support in queries --- pony/orm/sqltranslation.py | 36 ++++++++++++ pony/orm/tests/test_isinstance.py | 95 +++++++++++++++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 pony/orm/tests/test_isinstance.py diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index cc8d0250f..53ccba820 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2465,6 +2465,42 @@ def __call__(monad, *args, **kwargs): except TypeError as exc: reraise_improved_typeerror(exc, 'call', monad.type.__name__) +def get_classes(classinfo): + if isinstance(classinfo, EntityMonad): + yield classinfo.type.item_type + elif isinstance(classinfo, ListMonad): + for item in classinfo.items: + for type in get_classes(item): + yield type + else: throw(TypeError, ast2src(classinfo.node)) + +class FuncIsinstanceMonad(FuncMonad): + func = isinstance + def call(monad, obj, classinfo): + if not isinstance(obj, ObjectMixin): throw(ValueError, + 'Inside a query, isinstance first argument should be of entity type. 
Got: %s' % ast2src(obj.node)) + entity = obj.type + classes = list(get_classes(classinfo)) + subclasses = set() + for cls in classes: + if entity._root_ is cls._root_: + subclasses.add(cls) + subclasses.update(cls._subclasses_) + if entity in subclasses: + return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], nullable=False) + + subclasses.intersection_update(entity._subclasses_) + if not subclasses: + return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False) + + discr_attr = entity._discriminator_attr_ + assert discr_attr is not None + discr_values = [ [ 'VALUE', cls._discriminator_ ] for cls in subclasses ] + alias, pk_columns = obj.tableref.make_join(pk_only=True) + sql = [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] + return BoolExprMonad(sql, nullable=False) + + class FuncBufferMonad(FuncMonad): func = buffer def call(monad, source, encoding=None, errors=None): diff --git a/pony/orm/tests/test_isinstance.py b/pony/orm/tests/test_isinstance.py new file mode 100644 index 000000000..07b73531e --- /dev/null +++ b/pony/orm/tests/test_isinstance.py @@ -0,0 +1,95 @@ +from datetime import date +from decimal import Decimal + +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:', create_db=True) + +class Person(db.Entity): + id = PrimaryKey(int, auto=True) + name = Required(str) + dob = Optional(date) + ssn = Required(str, unique=True) + +class Student(Person): + group = Required("Group") + mentor = Optional("Teacher") + attend_courses = Set("Course") + +class Teacher(Person): + teach_courses = Set("Course") + apprentices = Set("Student") + salary = Required(Decimal) + +class Assistant(Student, Teacher): + pass + +class Professor(Teacher): + position = Required(str) + +class Group(db.Entity): + number = PrimaryKey(int) + students = Set("Student") + +class Course(db.Entity): + name = Required(str) + semester = Required(int) + students = Set(Student) + teachers = 
Set(Teacher) + PrimaryKey(name, semester) + +db.generate_mapping(create_tables=True) + +with db_session: + p = Person(name='Person1', ssn='SSN1') + g = Group(number=123) + prof = Professor(name='Professor1', salary=1000, position='position1', ssn='SSN5') + a1 = Assistant(name='Assistant1', group=g, salary=100, ssn='SSN4', mentor=prof) + a2 = Assistant(name='Assistant2', group=g, salary=200, ssn='SSN6', mentor=prof) + s1 = Student(name='Student1', group=g, ssn='SSN2', mentor=a1) + s2 = Student(name='Student2', group=g, ssn='SSN3') + +class TestVolatile(unittest.TestCase): + @db_session + def test_1(self): + q = select(p.name for p in Person if isinstance(p, Student)) + self.assertEqual(set(q), {'Student1', 'Student2', 'Assistant1', 'Assistant2'}) + + @db_session + def test_2(self): + q = select(p.name for p in Person if not isinstance(p, Student)) + self.assertEqual(set(q), {'Person1', 'Professor1'}) + + @db_session + def test_3(self): + q = select(p.name for p in Student if isinstance(p, Professor)) + self.assertEqual(set(q), set()) + + @db_session + def test_4(self): + q = select(p.name for p in Person if not isinstance(p, Person)) + self.assertEqual(set(q), set()) + + @db_session + def test_5(self): + q = select(p.name for p in Person if isinstance(p, (Student, Teacher))) + self.assertEqual(set(q), {'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'}) + + @db_session + def test_6(self): + q = select(p.name for p in Person if isinstance(p, Student) and isinstance(p, Teacher)) + self.assertEqual(set(q), {'Assistant1', 'Assistant2'}) + + @db_session + def test_7(self): + q = select(p.name for p in Person + if (isinstance(p, Student) and p.ssn == 'SSN2') + or (isinstance(p, Professor) and p.salary > 500)) + self.assertEqual(set(q), {'Student1', 'Professor1'}) + + +if __name__ == '__main__': + unittest.main() From 0f2ce9edf649592f3a403e36fda3e97751064061 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 30 Sep 2018 11:52:48 +0300 Subject: 
[PATCH 385/547] Test added --- pony/orm/tests/test_isinstance.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pony/orm/tests/test_isinstance.py b/pony/orm/tests/test_isinstance.py index 07b73531e..51c254f58 100644 --- a/pony/orm/tests/test_isinstance.py +++ b/pony/orm/tests/test_isinstance.py @@ -90,6 +90,10 @@ def test_7(self): or (isinstance(p, Professor) and p.salary > 500)) self.assertEqual(set(q), {'Student1', 'Professor1'}) + @db_session + def test_8(self): + q = select(p.name for p in Person if isinstance(p, Person)) + self.assertEqual(set(q), {'Person1', 'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'}) if __name__ == '__main__': unittest.main() From c3af5090dd1bb1b5d6aa6f2e40fd8753f2df8323 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 30 Sep 2018 12:01:47 +0300 Subject: [PATCH 386/547] Test added --- pony/orm/tests/test_isinstance.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pony/orm/tests/test_isinstance.py b/pony/orm/tests/test_isinstance.py index 51c254f58..a71ce9eb2 100644 --- a/pony/orm/tests/test_isinstance.py +++ b/pony/orm/tests/test_isinstance.py @@ -95,5 +95,11 @@ def test_8(self): q = select(p.name for p in Person if isinstance(p, Person)) self.assertEqual(set(q), {'Person1', 'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'}) + @db_session + def test_8(self): + q = select(g.number for g in Group if isinstance(g, Group)) + self.assertEqual(set(q), {123}) + + if __name__ == '__main__': unittest.main() From ea6fee178b38973875949782ec7d8bed4965a482 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 1 Oct 2018 14:15:13 +0300 Subject: [PATCH 387/547] Typo fixed --- pony/orm/tests/test_isinstance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/tests/test_isinstance.py b/pony/orm/tests/test_isinstance.py index a71ce9eb2..cd7be64cc 100644 --- a/pony/orm/tests/test_isinstance.py +++ b/pony/orm/tests/test_isinstance.py @@ -96,7 +96,7 @@ def 
test_8(self): self.assertEqual(set(q), {'Person1', 'Student1', 'Student2', 'Assistant1', 'Assistant2', 'Professor1'}) @db_session - def test_8(self): + def test_9(self): q = select(g.number for g in Group if isinstance(g, Group)) self.assertEqual(set(q), {123}) From 890ee84634a22469e64f8a23279a40db8a22b5c2 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 3 Oct 2018 18:04:37 +0300 Subject: [PATCH 388/547] Fixes #385: Tests fail with python3.6 --- setup.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 892beb2a7..e1689e30c 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,13 @@ from setuptools import setup import sys +import unittest + +def test_suite(): + test_loader = unittest.TestLoader() + test_suite = test_loader.discover('pony.orm.tests', pattern='test_*.py') + return test_suite + name = "pony" version = __import__('pony').__version__ description = "Pony Object-Relational Mapper" @@ -116,5 +123,6 @@ license=licence, packages=packages, package_data=package_data, - download_url=download_url + download_url=download_url, + test_suite='setup.test_suite' ) From 287a05ec1e4acedfc97b12b7d08d1c2dfcce753f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 5 Oct 2018 03:32:51 +0300 Subject: [PATCH 389/547] Fixes #386: `release unlocked lock` error in SQLite --- pony/orm/core.py | 64 +++++++++++++++++++--------------- pony/orm/tests/test_bug_386.py | 17 +++++++++ 2 files changed, 52 insertions(+), 29 deletions(-) create mode 100644 pony/orm/tests/test_bug_386.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 604d51cd5..8a99e1a02 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -870,7 +870,8 @@ def _exec_sql(database, sql, arguments=None, returning_id=False, start_transacti if local.debug: log_sql(sql, arguments) t = time() new_id = provider.execute(cursor, sql, arguments, returning_id) - if cache.immediate: cache.in_transaction = True + if cache.immediate: + cache.in_transaction = 
True database._update_local_stat(sql, t) if not returning_id: return cursor if PY2 and type(new_id) is long: new_id = int(new_id) @@ -1785,34 +1786,39 @@ def flush(cache): if cache.noflush_counter: return assert cache.is_alive assert not cache.saved_objects - if not cache.immediate: cache.immediate = True - for i in xrange(50): - if not cache.modified: return - - with cache.flush_disabled(): - for obj in cache.objects_to_save: # can grow during iteration - if obj is not None: obj._before_save_() - - cache.query_results.clear() - modified_m2m = cache._calc_modified_m2m() - for attr, (added, removed) in iteritems(modified_m2m): - if not removed: continue - attr.remove_m2m(removed) - for obj in cache.objects_to_save: - if obj is not None: obj._save_() - for attr, (added, removed) in iteritems(modified_m2m): - if not added: continue - attr.add_m2m(added) - - cache.max_id_cache.clear() - cache.modified_collections.clear() - cache.objects_to_save[:] = () - cache.modified = False - - cache.call_after_save_hooks() - else: - if cache.modified: throw(TransactionError, - 'Recursion depth limit reached in obj._after_save_() call') + prev_immediate = cache.immediate + cache.immediate = True + try: + for i in xrange(50): + if not cache.modified: return + + with cache.flush_disabled(): + for obj in cache.objects_to_save: # can grow during iteration + if obj is not None: obj._before_save_() + + cache.query_results.clear() + modified_m2m = cache._calc_modified_m2m() + for attr, (added, removed) in iteritems(modified_m2m): + if not removed: continue + attr.remove_m2m(removed) + for obj in cache.objects_to_save: + if obj is not None: obj._save_() + for attr, (added, removed) in iteritems(modified_m2m): + if not added: continue + attr.add_m2m(added) + + cache.max_id_cache.clear() + cache.modified_collections.clear() + cache.objects_to_save[:] = () + cache.modified = False + + cache.call_after_save_hooks() + else: + if cache.modified: throw(TransactionError, + 'Recursion depth limit 
reached in obj._after_save_() call') + finally: + if not cache.in_transaction: + cache.immediate = prev_immediate def call_after_save_hooks(cache): saved_objects = cache.saved_objects cache.saved_objects = [] diff --git a/pony/orm/tests/test_bug_386.py b/pony/orm/tests/test_bug_386.py new file mode 100644 index 000000000..fba12711d --- /dev/null +++ b/pony/orm/tests/test_bug_386.py @@ -0,0 +1,17 @@ +import unittest + +from pony import orm + +class Test(unittest.TestCase): + def test_1(self): + db = orm.Database('sqlite', ':memory:') + + class Person(db.Entity): + name = orm.Required(str) + + db.generate_mapping(create_tables=True) + + with orm.db_session: + a = Person(name='John') + a.delete() + Person.exists(name='Mike') From d4b0291432746ce8b348339e54087a94212c06a0 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 16 Oct 2018 13:44:47 +0300 Subject: [PATCH 390/547] Fixes #390: "TypeError: writable buffers are not hashable" --- pony/orm/dbapiprovider.py | 7 ++++-- pony/orm/tests/test_buffer.py | 41 ++++++++++++++++++++++++++++++++++ pony/orm/tests/test_bug_355.py | 23 ------------------- 3 files changed, 46 insertions(+), 25 deletions(-) create mode 100644 pony/orm/tests/test_buffer.py delete mode 100644 pony/orm/tests/test_bug_355.py diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 2f185018c..2238e756a 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -647,10 +647,13 @@ def validate(converter, val, obj=None): if isinstance(val, str): return buffer(val) throw(TypeError, "Attribute %r: expected type is 'buffer'. 
Got: %r" % (converter.attr, type(val))) def sql2py(converter, val): - if not isinstance(val, buffer) or \ - (PY2 and converter.attr.pk_offset is not None and 'read-write' in repr(val)): # Issue 355 + if not isinstance(val, buffer): try: val = buffer(val) except: pass + elif PY2 and converter.attr is not None and converter.attr.is_part_of_unique_index: + try: hash(val) + except TypeError: + val = buffer(val) return val def sql_type(converter): return 'BLOB' diff --git a/pony/orm/tests/test_buffer.py b/pony/orm/tests/test_buffer.py new file mode 100644 index 000000000..7dc9f13de --- /dev/null +++ b/pony/orm/tests/test_buffer.py @@ -0,0 +1,41 @@ +import unittest + +from pony import orm +from pony.py23compat import buffer + +db = orm.Database('sqlite', ':memory:') + +class Foo(db.Entity): + id = orm.PrimaryKey(int) + b = orm.Optional(orm.buffer) + +class Bar(db.Entity): + b = orm.PrimaryKey(orm.buffer) + +class Baz(db.Entity): + id = orm.PrimaryKey(int) + b = orm.Optional(orm.buffer, unique=True) + +db.generate_mapping(create_tables=True) + +buf = orm.buffer(b'123') + +with orm.db_session: + Foo(id=1, b=buf) + Bar(b=buf) + Baz(id=1, b=buf) + + +class Test(unittest.TestCase): + def test_1(self): # Bug #355 + with orm.db_session: + Bar[buf] + + def test_2(self): # Regression after #355 fix + with orm.db_session: + result = orm.select(bar.b for bar in Foo)[:] + self.assertEqual(result, [buf]) + + def test_3(self): # Bug #390 + with orm.db_session: + Baz.get(b=buf) diff --git a/pony/orm/tests/test_bug_355.py b/pony/orm/tests/test_bug_355.py deleted file mode 100644 index 9a579ebab..000000000 --- a/pony/orm/tests/test_bug_355.py +++ /dev/null @@ -1,23 +0,0 @@ -import unittest - -from pony import orm -from pony.py23compat import buffer - -class Test(unittest.TestCase): - def test_1(self): - db = orm.Database('sqlite', ':memory:') - - class Buf(db.Entity): - pk = orm.PrimaryKey(buffer) - - db.generate_mapping(create_tables=True) - - x = buffer(b'123') - - with orm.db_session: 
- Buf(pk=x) - orm.commit() - - with orm.db_session: - t = Buf[x] - From 1934624b7aa5f4106a62aa0b6bf90c94a8c395a7 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 9 Oct 2018 18:29:34 +0300 Subject: [PATCH 391/547] Entity __getitem__ works with Entity.get_pk() even if pk is compostite --- pony/orm/core.py | 24 +++++++----- pony/orm/tests/test_get_pk.py | 73 +++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 9 deletions(-) create mode 100644 pony/orm/tests/test_get_pk.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 8a99e1a02..d906e29d0 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3761,11 +3761,14 @@ def __iter__(entity): @cut_traceback def __getitem__(entity, key): if type(key) is not tuple: key = (key,) - if len(key) != len(entity._pk_attrs_): - throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)' - % (entity.__name__, len(key), len(entity._pk_attrs_))) - kwargs = {attr.name: value for attr, value in izip(entity._pk_attrs_, key)} - return entity._find_one_(kwargs) + if len(key) == len(entity._pk_attrs_): + kwargs = {attr.name: value for attr, value in izip(entity._pk_attrs_, key)} + return entity._find_one_(kwargs) + if len(key) == len(entity._pk_columns_): + return entity._get_by_raw_pkval_(key, from_db=False, seed=False) + + throw(TypeError, 'Invalid count of attrs in %s primary key (%s instead of %s)' + % (entity.__name__, len(key), len(entity._pk_attrs_))) @cut_traceback def exists(entity, *args, **kwargs): if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).exists() @@ -4210,7 +4213,7 @@ def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs= assert cache.in_transaction cache.for_update.add(obj) return obj - def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True): + def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True, seed=True): i = 0 pkval = [] for attr in entity._pk_attrs_: @@ 
-4218,16 +4221,19 @@ def _get_by_raw_pkval_(entity, raw_pkval, for_update=False, from_db=True): val = raw_pkval[i] i += 1 if not attr.reverse: val = attr.validate(val, None, entity, from_db=from_db) - else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db) + else: val = attr.py_type._get_by_raw_pkval_((val,), from_db=from_db, seed=seed) else: if not attr.reverse: throw(NotImplementedError) vals = raw_pkval[i:i+len(attr.columns)] - val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db) + val = attr.py_type._get_by_raw_pkval_(vals, from_db=from_db, seed=seed) i += len(attr.columns) pkval.append(val) if not entity._pk_is_composite_: pkval = pkval[0] else: pkval = tuple(pkval) - obj = entity._get_from_identity_map_(pkval, 'loaded', for_update) + if seed: + obj = entity._get_from_identity_map_(pkval, 'loaded', for_update) + else: + obj = entity[pkval] assert obj._status_ != 'cancelled' return obj def _get_propagation_mixin_(entity): diff --git a/pony/orm/tests/test_get_pk.py b/pony/orm/tests/test_get_pk.py new file mode 100644 index 000000000..4579a83a0 --- /dev/null +++ b/pony/orm/tests/test_get_pk.py @@ -0,0 +1,73 @@ +from pony.py23compat import basestring + +import unittest + +from pony.orm import * +from pony import orm +from pony.utils import cached_property +from datetime import date + + +class Test(unittest.TestCase): + + @cached_property + def db(self): + return orm.Database('sqlite', ':memory:') + + def setUp(self): + db = self.db + self.day = date.today() + + class A(db.Entity): + b = Required("B") + c = Required("C") + PrimaryKey(b, c) + + class B(db.Entity): + id = PrimaryKey(date) + a_set = Set(A) + + class C(db.Entity): + x = Required("X") + y = Required("Y") + a_set = Set(A) + PrimaryKey(x, y) + + class X(db.Entity): + id = PrimaryKey(int) + c_set = Set(C) + + class Y(db.Entity): + id = PrimaryKey(int) + c_set = Set(C) + + db.generate_mapping(check_tables=True, create_tables=True) + + with orm.db_session: + x1 = X(id=123) + y1 = 
Y(id=456) + b1 = B(id=self.day) + c1 = C(x=x1, y=y1) + A(b=b1, c=c1) + + + @db_session + def test_1(self): + a1 = self.db.A.select().first() + a2 = self.db.A[a1.get_pk()] + self.assertEqual(a1, a2) + + @db_session + def test2(self): + a = self.db.A.select().first() + b = self.db.B.select().first() + c = self.db.C.select().first() + pk = (b.get_pk(), c._get_raw_pkval_()) + self.assertTrue(a is self.db.A[pk]) + + @db_session + def test3(self): + a = self.db.A.select().first() + c = self.db.C.select().first() + pk = (self.day, c.get_pk()) + self.assertTrue(a is self.db.A[pk]) \ No newline at end of file From b237258bb6bb190362abb99ac8a60b208513478d Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 20 Oct 2018 14:18:46 +0300 Subject: [PATCH 392/547] Fixes #398: Added support of numpy types --- pony/orm/dbapiprovider.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 2238e756a..33b2fe164 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -511,6 +511,8 @@ def init(converter, kwargs): converter.unsigned = unsigned def validate(converter, val, obj=None): if isinstance(val, int_types): pass + elif hasattr(val, '__index__'): + val = val.__index__() elif isinstance(val, basestring): try: val = int(val) except ValueError: throw(ValueError, From 4f0f2add9d0502cd8804d0dff4278e1c685662c4 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 20 Oct 2018 16:20:25 +0300 Subject: [PATCH 393/547] Fixes #380: async support --- pony/orm/core.py | 12 +++++++----- pony/orm/tests/test_generator_db_session.py | 18 ++++++------------ 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index d906e29d0..4484f7ec9 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -520,13 +520,14 @@ def wrapped_interact(iterator, input=None, exc_info=None): try: output = interact(iterator, input, exc_info) except StopIteration as e: + commit() for cache 
in _get_caches(): - if cache.modified or cache.in_transaction: throw(TransactionError, - 'You need to manually commit() changes before exiting from the generator') - raise + cache.release() + assert not local.db2cache + raise e for cache in _get_caches(): if cache.modified or cache.in_transaction: throw(TransactionError, - 'You need to manually commit() changes before yielding from the generator') + 'You need to manually commit() changes before suspending the generator') except: rollback_and_reraise(sys.exc_info()) else: @@ -541,8 +542,8 @@ def wrapped_interact(iterator, input=None, exc_info=None): gen = gen_func(*args, **kwargs) iterator = gen.__await__() if hasattr(gen, '__await__') else iter(gen) - output = wrapped_interact(iterator) try: + output = wrapped_interact(iterator) while True: try: input = yield output @@ -551,6 +552,7 @@ def wrapped_interact(iterator, input=None, exc_info=None): else: output = wrapped_interact(iterator, input) except StopIteration: + assert not db2cache_copy and not local.db2cache return if hasattr(types, 'coroutine'): diff --git a/pony/orm/tests/test_generator_db_session.py b/pony/orm/tests/test_generator_db_session.py index a76df966c..8f54615d1 100644 --- a/pony/orm/tests/test_generator_db_session.py +++ b/pony/orm/tests/test_generator_db_session.py @@ -119,7 +119,7 @@ def f(id1): a2 = self.Account[2] self.assertEqual(a2.amount, 2100) - @raises_exception(TransactionError, 'You need to manually commit() changes before yielding from the generator') + @raises_exception(TransactionError, 'You need to manually commit() changes before suspending the generator') def test8(self): @db_session def f(id1): @@ -141,7 +141,6 @@ def f(id1): for amount in f(1): pass - @raises_exception(TransactionError, 'You need to manually commit() changes before exiting from the generator') def test10(self): @db_session def f(id1): @@ -149,19 +148,14 @@ def f(id1): yield a1.amount a1.amount += 100 + with db_session: + a = self.Account[1].amount for amount in 
f(1): pass + with db_session: + b = self.Account[1].amount - def test11(self): - @db_session - def f(id1): - a1 = self.Account[id1] - yield a1.amount - a1.amount += 100 - commit() - - for amount in f(1): - pass + self.assertEqual(b, a + 100) def test12(self): @db_session From e19013d1bb6e9c89eb79477923375a6e6e42f727 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 6 Nov 2018 16:03:24 +0300 Subject: [PATCH 394/547] Composite index bug fixed: https://stackoverflow.com/questions/53147694/pony-orm-why-do-i-get-an-assertion-error-when-using-the-exists-command --- pony/orm/core.py | 3 ++- pony/orm/tests/test_indexes.py | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 4484f7ec9..b1b614bbd 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4637,7 +4637,8 @@ def _db_set_(obj, avdict, unpickling=False): for i, attr in enumerate(attrs): if attr in avdict: vals[i] = avdict[attr] new_vals = tuple(vals) - cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) + if prev_vals != new_vals: + cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) for attr, new_val in iteritems(avdict): if not attr.reverse: diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py index f6399a147..d11febca8 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -92,5 +92,24 @@ class Person(db.Entity): p1.set(name='John', age=19) p1.delete() + def test_5(self): + db = Database('sqlite', ':memory:') + + class Table1(db.Entity): + name = Required(str) + table2s = Set('Table2') + + class Table2(db.Entity): + height = Required(int) + length = Required(int) + table1 = Optional('Table1') + composite_key(height, length, table1) + + db.generate_mapping(create_tables=True) + + with db_session: + Table2(height=2, length=1) + Table2.exists(height=2, length=1) + if __name__ == '__main__': unittest.main() From 29bbf4a8f7e45160847462555e5d788af5b563f0 
Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 29 Oct 2018 21:50:56 +0300 Subject: [PATCH 395/547] fix db_session(sql_debug=True): it should log SQL commands also during db_session.__exit__() --- pony/orm/core.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b1b614bbd..7be0e903a 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -426,11 +426,15 @@ def _enter(db_session): if db_session.sql_debug is not None: local.push_debug_state(db_session.sql_debug, db_session.show_values) def __exit__(db_session, exc_type=None, exc=None, tb=None): - if db_session.sql_debug is not None: - local.pop_debug_state() local.db_context_counter -= 1 - if local.db_context_counter: return - assert local.db_session is db_session + try: + if not local.db_context_counter: + assert local.db_session is db_session + db_session._commit_or_rollback(exc_type, exc, tb) + finally: + if db_session.sql_debug is not None: + local.pop_debug_state() + def _commit_or_rollback(db_session, exc_type, exc, tb): try: if exc_type is None: can_commit = True elif not callable(db_session.allowed_exceptions): From 65d4a41f28603b598750b4e90b1c4f52bd178a2a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 23 Oct 2018 18:04:32 +0300 Subject: [PATCH 396/547] Many-to-many collection loading bug fixed --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7be0e903a..7126978a4 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2795,13 +2795,13 @@ def load(attr, obj, items=None): else: d[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} for obj2, items in iteritems(d): setdata2 = obj2._vals_.get(attr) - if setdata2 is None: setdata2 = obj._vals_[attr] = SetData() + if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData() else: phantoms = setdata2 - items if setdata2.added: phantoms -= setdata2.added if phantoms: 
throw(UnrepeatableReadError, 'Phantom object %s disappeared from collection %s.%s' - % (safe_repr(phantoms.pop()), safe_repr(obj), attr.name)) + % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name)) items -= setdata2 if setdata2.removed: items -= setdata2.removed setdata2 |= items From a439cb9ef0f34aa160aa59abb906ca0c6926d637 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 7 Aug 2018 15:06:40 +0300 Subject: [PATCH 397/547] Move code around --- pony/orm/core.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7126978a4..e67abf5fb 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5954,12 +5954,6 @@ def random(query, limit): def to_json(query, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return query._database.to_json(query[:], include, exclude, converter, with_schema, schema_hash) -def strcut(s, width): - if len(s) <= width: - return s + ' ' * (width - len(s)) - else: - return s[:width-3] + '...' - class QueryResultIterator(object): __slots__ = '_query_result', '_position' @@ -6147,6 +6141,13 @@ def to_list(self): remove = make_query_result_method_error_stub('remove', 'remove') +def strcut(s, width): + if len(s) <= width: + return s + ' ' * (width - len(s)) + else: + return s[:width-3] + '...' 
+ + @cut_traceback def show(entity): x = entity From bfb3a72c938c5390383d934fc24d3a468d8169e7 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 6 Nov 2018 19:02:11 +0300 Subject: [PATCH 398/547] py_json_unwrap fix --- pony/orm/dbproviders/sqlite.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index dc356fc6f..e891c6983 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -428,6 +428,8 @@ def func(value): def py_json_unwrap(value): # [null,some-value] -> some-value + if value is None: + return None assert value.startswith('[null,'), value return value[6:-1] From 6e041aa766652ccd901a17878526eb79fa8a97eb Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 23 Oct 2018 16:02:01 +0300 Subject: [PATCH 399/547] Array type support added for PostgreSQL and SQLite --- pony/orm/core.py | 18 +++-- pony/orm/dbapiprovider.py | 53 +++++++++++++- pony/orm/dbproviders/postgres.py | 27 ++++++- pony/orm/dbproviders/sqlite.py | 79 ++++++++++++++++++++- pony/orm/dbschema.py | 3 + pony/orm/ormtypes.py | 26 +++++++ pony/orm/sqlbuilding.py | 12 ++++ pony/orm/sqltranslation.py | 117 +++++++++++++++++++++++++++++-- 8 files changed, 319 insertions(+), 16 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e67abf5fb..dd138cb64 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -20,7 +20,10 @@ import pony from pony import options from pony.orm.decompiling import decompile -from pony.orm.ormtypes import LongStr, LongUnicode, numeric_types, RawSQL, normalize, Json, TrackedValue, QueryType +from pony.orm.ormtypes import ( + LongStr, LongUnicode, numeric_types, RawSQL, normalize, Json, TrackedValue, QueryType, + Array, IntArray, StrArray, FloatArray + ) from pony.orm.asttranslation import ast2src, create_extractors, TranslationError from pony.orm.dbapiprovider import ( DBAPIProvider, DBException, Warning, Error, InterfaceError, DatabaseError, DataError, @@ 
-54,7 +57,7 @@ 'composite_key', 'composite_index', 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', - 'LongStr', 'LongUnicode', 'Json', + 'LongStr', 'LongUnicode', 'Json', 'IntArray', 'StrArray', 'FloatArray', 'select', 'left_join', 'get', 'exists', 'delete', @@ -1046,8 +1049,11 @@ def get_columns(table, column_names): child_columns = get_columns(table, attr.columns) table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index) elif attr.index and attr.columns: - columns = tuple(imap(table.column_dict.__getitem__, attr.columns)) - table.add_index(attr.index, columns, is_unique=attr.is_unique) + if isinstance(attr.py_type, Array) and provider.dialect != 'PostgreSQL': + pass # GIN indexes are supported only in PostgreSQL + else: + columns = tuple(imap(table.column_dict.__getitem__, attr.columns)) + table.add_index(attr.index, columns, is_unique=attr.is_unique) entity._initialize_bits_() if create_tables: database.create_tables(check_tables) @@ -1952,7 +1958,7 @@ def __init__(attr, py_type, *args, **kwargs): if attr.is_pk: attr.pk_offset = 0 else: attr.pk_offset = None attr.id = next(attr_id_counter) - if not isinstance(py_type, (type, basestring, types.FunctionType)): + if not isinstance(py_type, (type, basestring, types.FunctionType, Array)): if py_type is datetime: throw(TypeError, 'datetime is the module and cannot be used as attribute type. 
Use datetime.datetime instead') throw(TypeError, 'Incorrect type of attribute: %r' % py_type) @@ -3693,7 +3699,7 @@ def _link_reverse_attrs_(entity): database = entity._database_ for attr in entity._new_attrs_: py_type = attr.py_type - if not issubclass(py_type, Entity): continue + if not isinstance(py_type, EntityMeta): continue entity2 = py_type if entity2._database_ is not database: diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 33b2fe164..aadd20532 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -9,7 +9,7 @@ import pony from pony.utils import is_utf8, decorator, throw, localbase, deprecated from pony.converting import str2date, str2time, str2datetime, str2timedelta -from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, Json, QueryType +from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedList, Json, QueryType, Array class DBException(Exception): def __init__(exc, original_exc, *args): @@ -101,6 +101,7 @@ class DBAPIProvider(object): dbschema_cls = None translator_cls = None sqlbuilder_cls = None + array_converter_cls = None name_before_table = 'schema_name' default_schema_name = None @@ -270,6 +271,11 @@ def _get_converter_type_by_py_type(provider, py_type): if isinstance(py_type, type): for t, converter_cls in provider.converter_classes: if issubclass(py_type, t): return converter_cls + if isinstance(py_type, Array): + converter_cls = provider.array_converter_cls + if converter_cls is None: + throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect) + return converter_cls if isinstance(py_type, RawSQLType): return Converter # for cases like select(raw_sql(...) 
for x in X) throw(TypeError, 'No database converter found for type %s' % py_type) @@ -803,3 +809,48 @@ def dbvals_equal(converter, x, y): return x == y def sql_type(converter): return "JSON" + +class ArrayConverter(Converter): + array_types = { + int: ('int', IntConverter), + unicode: ('text', StrConverter), + float: ('real', RealConverter) + } + + def __init__(converter, provider, py_type, attr=None): + Converter.__init__(converter, provider, py_type, attr) + converter.item_converter = converter.array_types[converter.py_type.item_type][1] + + def validate(converter, val, obj=None): + if obj is None or converter.attr is None: + return val + if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: + return val + + if isinstance(val, basestring) or not hasattr(val, '__len__'): + val = [val] + else: + val = list(val) + item_type = converter.py_type.item_type + if item_type == float: + item_type = (float, int) + for i, v in enumerate(val): + if not isinstance(v, item_type): + if hasattr(v, '__index__'): + val[i] = v.__index__() + else: + throw(TypeError, 'Cannot store %s item in array of %s' % + (type(v).__name__, converter.py_type.item_type.__name__)) + + return TrackedList(obj, converter.attr, val) + + def dbval2val(converter, dbval, obj=None): + if obj is None: + return dbval + return TrackedList(obj, converter.attr, dbval) + + def val2dbval(converter, val, obj=None): + return list(val) + + def sql_type(converter): + return '%s[]' % converter.array_types[converter.py_type.item_type][0] diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index eb332de94..c93d801cc 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -27,7 +27,7 @@ from pony.orm.core import log_orm from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.orm.sqltranslation import SQLTranslator -from pony.orm.sqlbuilding import Value, SQLBuilder +from pony.orm.sqlbuilding 
import Value, SQLBuilder, join from pony.converting import timedelta2str from pony.utils import is_ident @@ -121,6 +121,23 @@ def GROUP_CONCAT(builder, distinct, expr, sep=None): else: result = result, ", ','" return result, ')' + def ARRAY_INDEX(builder, col, index): + return builder(col), '[', builder(index), ']' + def ARRAY_CONTAINS(builder, key, not_in, col): + if not_in: + return builder(key), ' <> ALL(', builder(col), ')' + return builder(key), ' = ANY(', builder(col), ')' + def ARRAY_SUBSET(builder, array1, not_in, array2): + result = builder(array1), ' <@ ', builder(array2) + if not_in: + result = 'NOT (', result, ')' + return result + def ARRAY_LENGTH(builder, array): + return 'COALESCE(ARRAY_LENGTH(', builder(array), ', 1), 0)' + def ARRAY_SLICE(builder, array, start, stop): + return builder(array), '[', builder(start) if start else '', ':', builder(stop) if stop else '', ']' + def MAKE_ARRAY(builder, *items): + return 'ARRAY[', join(', ', (builder(item) for item in items)), ']' class PGStrConverter(dbapiprovider.StrConverter): @@ -157,6 +174,13 @@ class PGJsonConverter(dbapiprovider.JsonConverter): def sql_type(self): return "JSONB" +class PGArrayConverter(dbapiprovider.ArrayConverter): + array_types = { + int: ('int', PGIntConverter), + unicode: ('text', PGStrConverter), + float: ('double precision', PGRealConverter) + } + class PGPool(Pool): def _connect(pool): pool.con = pool.dbapi_module.connect(*pool.args, **pool.kwargs) @@ -184,6 +208,7 @@ class PGProvider(DBAPIProvider): dbschema_cls = PGSchema translator_cls = PGTranslator sqlbuilder_cls = PGSQLBuilder + array_converter_cls = PGArrayConverter default_schema_name = 'public' diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index e891c6983..207b3fb7c 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -14,7 +14,7 @@ from pony.orm import core, dbschema, dbapiprovider from pony.orm.core import log_orm -from pony.orm.ormtypes import Json +from 
pony.orm.ormtypes import Json, TrackedList from pony.orm.sqltranslation import SQLTranslator, StringExprMonad from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions @@ -151,6 +151,20 @@ def JSON_ARRAY_LENGTH(builder, value): def JSON_CONTAINS(builder, expr, path, key): path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'py_json_contains(', builder(expr), ', ', path_sql, ', ', builder(key), ')' + def ARRAY_INDEX(builder, col, index): + return 'py_array_index(', builder(col), ', ', builder(index), ')' + def ARRAY_CONTAINS(builder, key, not_in, col): + return ('NOT ' if not_in else ''), 'py_array_contains(', builder(col), ', ', builder(key), ')' + def ARRAY_SUBSET(builder, array1, not_in, array2): + return ('NOT ' if not_in else ''), 'py_array_subset(', builder(array2), ', ', builder(array1), ')' + def ARRAY_LENGTH(builder, array): + return 'py_array_length(', builder(array), ')' + def ARRAY_SLICE(builder, array, start, stop): + return 'py_array_slice(', builder(array), ', ', \ + builder(start) if start else 'null', ',',\ + builder(stop) if stop else 'null', ')' + def MAKE_ARRAY(builder, *items): + return 'py_make_array(', join(', ', (builder(item) for item in items)), ')' class SQLiteIntConverter(dbapiprovider.IntConverter): def sql_type(converter): @@ -206,6 +220,24 @@ def py2sql(converter, val): class SQLiteJsonConverter(dbapiprovider.JsonConverter): json_kwargs = {'separators': (',', ':'), 'sort_keys': True, 'ensure_ascii': False} +def dumps(items): + return json.dumps(items, **SQLiteJsonConverter.json_kwargs) + +class SQLiteArrayConverter(dbapiprovider.ArrayConverter): + array_types = { + int: ('int', SQLiteIntConverter), + unicode: ('text', dbapiprovider.StrConverter), + float: ('real', dbapiprovider.RealConverter) + } + + def dbval2val(converter, dbval, obj=None): + items = json.loads(dbval) if dbval else [] + if obj is None: + return items + return 
TrackedList(obj, converter.attr, items) + + def val2dbval(converter, val, obj=None): + return dumps(val) class LocalExceptions(localbase): def __init__(self): @@ -240,6 +272,7 @@ class SQLiteProvider(DBAPIProvider): dbschema_cls = SQLiteSchema translator_cls = SQLiteTranslator sqlbuilder_cls = SQLiteBuilder + array_converter_cls = SQLiteArrayConverter name_before_table = 'db_name' @@ -511,6 +544,43 @@ def py_json_array_length(expr, path=None): expr = _traverse(expr, keys) return len(expr) if type(expr) is list else 0 +def wrap_array_func(func): + @wraps(func) + def new_func(array, *args): + if array is None: + return None + array = json.loads(array) + return func(array, *args) + return new_func + +@wrap_array_func +def py_array_index(array, index): + try: + return array[index] + except IndexError: + return None + +@wrap_array_func +def py_array_contains(array, item): + return item in array + +@wrap_array_func +def py_array_subset(array, items): + if items is None: return None + items = json.loads(items) + return set(items).issubset(set(array)) + +@wrap_array_func +def py_array_length(array): + return len(array) + +@wrap_array_func +def py_array_slice(array, start, stop): + return dumps(array[start:stop]) + +def py_make_array(*items): + return dumps(items) + class SQLitePool(Pool): def __init__(pool, filename, create_db, **kwargs): # called separately in each thread pool.filename = filename @@ -538,6 +608,13 @@ def create_function(name, num_params, func): create_function('py_json_nonzero', 2, py_json_nonzero) create_function('py_json_array_length', -1, py_json_array_length) + create_function('py_array_index', 2, py_array_index) + create_function('py_array_contains', 2, py_array_contains) + create_function('py_array_subset', 2, py_array_subset) + create_function('py_array_length', 1, py_array_length) + create_function('py_array_slice', 3, py_array_slice) + create_function('py_make_array', -1, py_make_array) + if sqlite.sqlite_version_info >= (3, 6, 19): 
con.execute('PRAGMA foreign_keys = true') diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 0fd971ae8..5079b49f6 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -307,6 +307,9 @@ def _get_create_sql(index, inside_table): append(quote_name(index.name)) append(case('ON')) append(quote_name(index.table.name)) + converter = index.columns[0].converter + if isinstance(converter.py_type, core.Array) and converter.provider.dialect == 'PostgreSQL': + append(case('USING GIN')) else: if index.name: append(case('CONSTRAINT')) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 5ab634a0e..5fc576732 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -348,3 +348,29 @@ def __init__(self, wrapped): def __repr__(self): return '' % self.wrapped + +class Array(object): + def __init__(self, item_type): + if item_type not in(unicode, int, float): + throw(NotImplementedError, 'Only int, float and str types are supported. Got: `Array(%r)`' % item_type) + self.item_type = item_type + + def __repr__(self): + return 'Array(%s)' % self.item_type.__name__ + + def __deepcopy__(self, memo): + return self + + def __eq__(self, other): + return type(other) is Array and self.item_type == other.item_type + + def __ne__(self, other): + return not self.__eq__(other) + + def __hash__(self): + return hash(self.item_type) + +IntArray = Array(int) +StrArray = Array(unicode) +FloatArray = Array(float) + diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index b38034388..ac61689da 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -605,3 +605,15 @@ def JSON_ARRAY_LENGTH(builder, value): throw(NotImplementedError) def JSON_PARAM(builder, expr): return builder(expr) + def ARRAY_INDEX(builder, col, index): + throw(NotImplementedError) + def ARRAY_CONTAINS(builder, key, not_in, col): + throw(NotImplementedError) + def ARRAY_SUBSET(builder, array1, not_in, array2): + throw(NotImplementedError) + def ARRAY_LENGTH(builder, array): + 
throw(NotImplementedError) + def ARRAY_SLICE(builder, array, start, stop): + throw(NotImplementedError) + def MAKE_ARRAY(builder, *items): + throw(NotImplementedError) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 53ccba820..59dc68d7a 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -18,7 +18,7 @@ from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ normalize, normalize_type, coerce_types, are_comparable_types, \ - Json, QueryType + Json, QueryType, Array from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper, \ special_functions, const_functions, extract_vars, Query, UseAnotherTranslator @@ -141,12 +141,32 @@ def dispatch_external(translator, node): monad = ConstMonad.new(value) elif tt is tuple: params = [] - for i, item_type in enumerate(t): - if item_type is NoneType: - throw(TypeError, 'Expression `%s` should not contain None values' % node.src) - param = ParamMonad.new(item_type, (varkey, i, None)) - params.append(param) - monad = ListMonad(params) + is_array = False + if translator.database.provider.array_converter_cls is None: + types = set(t) + if len(types) == 1 and unicode in types: + item_type = unicode + is_array = True + else: + item_type = int + for type_ in types: + if type_ is float: + item_type = float + if type_ not in (float, int) or not hasattr(type_, '__index__'): + break + else: + is_array = True + + if is_array: + array_type = Array(item_type) + monad = ArrayParamMonad(array_type, (varkey, None, None)) + else: + for i, item_type in enumerate(t): + if item_type is NoneType: + throw(TypeError, 'Expression `%s` should not contain None values' % node.src) + param = ParamMonad.new(item_type, (varkey, i, None)) + params.append(param) + monad = ListMonad(params) elif isinstance(t, RawSQLType): monad = RawSQLMonad(t, varkey) else: @@ -2017,6 +2037,83 @@ def 
cast_from_json(monad, type): def nonzero(monad): return BoolExprMonad([ 'JSON_NONZERO', monad.getsql()[0] ]) +class ArrayMixin(MonadMixin): + def contains(monad, key, not_in=False): + if key.type is monad.type.item_type: + sql = 'ARRAY_CONTAINS', key.getsql()[0], not_in, monad.getsql()[0] + return BoolExprMonad(sql) + if isinstance(key, ListMonad): + sql = [ 'MAKE_ARRAY' ] + sql.extend(item.getsql()[0] for item in key.items) + sql = 'ARRAY_SUBSET', sql, not_in, monad.getsql()[0] + return BoolExprMonad(sql) + elif isinstance(key, ArrayParamMonad): + sql = 'ARRAY_SUBSET', key.getsql()[0], not_in, monad.getsql()[0] + return BoolExprMonad(sql) + throw(TypeError, 'Cannot search for %s in %s: {EXPR}' % + (type2str(key.type), type2str(monad.type))) + + def len(monad): + sql = ['ARRAY_LENGTH', monad.getsql()[0]] + return NumericExprMonad(int, sql) + + def nonzero(monad): + return BoolExprMonad(['GT', ['ARRAY_LENGTH', monad.getsql()[0]], ['VALUE', 0]]) + + def __getitem__(monad, index): + if isinstance(index, NumericConstMonad): + expr_sql = monad.getsql()[0] + index = index.getsql()[0] + value = index[1] + if not monad.translator.database.provider.dialect == 'SQLite': + if value >= 0: + index = ['VALUE', value + 1] + else: + index = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value) + 1]] + + sql = ['ARRAY_INDEX', expr_sql, index] + return ExprMonad.new(monad.type.item_type, sql) + elif isinstance(index, slice): + if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') + start, stop = index.start, index.stop + if start is None and stop is None: + return monad + + if start is not None and start.type is not int: + throw(TypeError, "Invalid type of start index (expected 'int', got %r) in array slice {EXPR}" + % type2str(start.type)) + if stop is not None and stop.type is not int: + throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in array slice {EXPR}" + % type2str(stop.type)) + + if (start is not None and not 
isinstance(start, NumericConstMonad)) or \ + (stop is not None and not isinstance(stop, NumericConstMonad)): + throw(TypeError, 'Array indices should be type of int') + + expr_sql = monad.getsql()[0] + + if not monad.translator.database.provider.dialect == 'SQLite': + if start is None: + start_sql = None + elif start.value >= 0: + start_sql = ['VALUE', start.value + 1] + else: + start_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(start.value) + 1]] + + if stop is None: + stop_sql = None + elif stop.value >= 0: + stop_sql = ['VALUE', stop.value + 1] + else: + stop_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(stop.value) + 1]] + else: + start_sql = None if start is None else ['VALUE', start.value] + stop_sql = None if stop is None else ['VALUE', stop.value] + + sql = ['ARRAY_SLICE', expr_sql, start_sql, stop_sql] + return ExprMonad.new(monad.type, sql) + + class ObjectMixin(MonadMixin): def mixin_init(monad): assert isinstance(monad.type, EntityMeta) @@ -2071,6 +2168,7 @@ def new(parent, attr, *args, **kwargs): elif type is UUID: cls = UuidAttrMonad elif type is Json: cls = JsonAttrMonad elif isinstance(type, EntityMeta): cls = ObjectAttrMonad + elif isinstance(type, Array): cls = ArrayAttrMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(parent, attr, *args, **kwargs) def __new__(cls, *args): @@ -2124,6 +2222,7 @@ class DatetimeAttrMonad(DatetimeMixin, AttrMonad): pass class BufferAttrMonad(BufferMixin, AttrMonad): pass class UuidAttrMonad(UuidMixin, AttrMonad): pass class JsonAttrMonad(JsonMixin, AttrMonad): pass +class ArrayAttrMonad(ArrayMixin, AttrMonad): pass class ParamMonad(Monad): @staticmethod @@ -2138,6 +2237,7 @@ def new(type, paramkey): elif type is buffer: cls = BufferParamMonad elif type is UUID: cls = UuidParamMonad elif type is Json: cls = JsonParamMonad + elif type is Array: cls = ArrayParamMonad elif isinstance(type, EntityMeta): cls = ObjectParamMonad else: throw(NotImplementedError, 'Parameter {EXPR} 
has unsupported type %r' % (type,)) result = cls(type, paramkey) @@ -2180,6 +2280,7 @@ class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass +class ArrayParamMonad(ArrayMixin, ParamMonad): pass class JsonParamMonad(JsonMixin, ParamMonad): def getsql(monad, sqlquery=None): @@ -2196,6 +2297,7 @@ def new(type, sql, nullable=True): elif type is datetime: cls = DatetimeExprMonad elif type is Json: cls = JsonExprMonad elif isinstance(type, EntityMeta): cls = ObjectExprMonad + elif isinstance(type, Array): cls = ArrayExprMonad else: throw(NotImplementedError, type) # pragma: no cover return cls(type, sql, nullable=nullable) def __new__(cls, *args, **kwargs): @@ -2218,6 +2320,7 @@ class TimeExprMonad(TimeMixin, ExprMonad): pass class TimedeltaExprMonad(TimedeltaMixin, ExprMonad): pass class DatetimeExprMonad(DatetimeMixin, ExprMonad): pass class JsonExprMonad(JsonMixin, ExprMonad): pass +class ArrayExprMonad(ArrayMixin, ExprMonad): pass class JsonItemMonad(JsonMixin, Monad): def __init__(monad, parent, key): From 20015292aab882e437d55822353c2b9cbf1f76b7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 17 Oct 2018 16:03:56 +0300 Subject: [PATCH 400/547] remove options.PREFETCHING --- pony/options.py | 1 - pony/orm/core.py | 9 ++++----- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pony/options.py b/pony/options.py index 8e26fcad6..6c31ab487 100644 --- a/pony/options.py +++ b/pony/options.py @@ -59,7 +59,6 @@ CONSOLE_ENCODING = None # db options -PREFETCHING = True MAX_FETCH_COUNT = None # used for select(...).show() diff --git a/pony/orm/core.py b/pony/orm/core.py index dd138cb64..866592ad4 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2762,7 +2762,7 @@ def load(attr, obj, items=None): counter = cache.collection_statistics.setdefault(attr, 0) nplus1_threshold = 
attr.nplus1_threshold - prefetching = options.PREFETCHING and not attr.lazy and nplus1_threshold is not None \ + prefetching = not attr.lazy and nplus1_threshold is not None \ and (counter >= nplus1_threshold or cache.noflush_counter) objects = [ obj ] @@ -4514,10 +4514,9 @@ def _load_(obj): seeds = cache.seeds[entity._pk_attrs_] max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) objects = [ obj ] - if options.PREFETCHING: - for seed in seeds: - if len(objects) >= max_batch_size: break - if seed is not obj: objects.append(seed) + for seed in seeds: + if len(objects) >= max_batch_size: break + if seed is not obj: objects.append(seed) sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(objects)) arguments = adapter(objects) cursor = database._exec_sql(sql, arguments) From f5292dac38fa459f1856273182fa46731f884a83 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 7 Aug 2018 15:41:22 +0300 Subject: [PATCH 401/547] Improved query prefetching --- pony/orm/core.py | 289 ++++++++++++++++++++--------- pony/orm/sqltranslation.py | 4 +- pony/orm/tests/test_prefetching.py | 74 +++++++- 3 files changed, 265 insertions(+), 102 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 866592ad4..69d73120b 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -282,6 +282,42 @@ def adapt_sql(sql, paramstyle): adapted_sql_cache[(sql, paramstyle)] = result return result + +class PrefetchContext(object): + def __init__(self, database=None): + self.database = database + self.attrs_to_prefetch_dict = defaultdict(set) + self.entities_to_prefetch = set() + self.relations_to_prefetch_cache = {} + def copy(self): + result = PrefetchContext(self.database) + result.attrs_to_prefetch_dict = self.attrs_to_prefetch_dict.copy() + result.entities_to_prefetch = self.entities_to_prefetch.copy() + return result + def __enter__(self): + assert local.prefetch_context is None + local.prefetch_context = self + def __exit__(self, exc_type, 
exc_val, exc_tb): + assert local.prefetch_context is self + local.prefetch_context = None + def get_frozen_attrs_to_prefetch(self, entity): + attrs_to_prefetch = self.attrs_to_prefetch_dict.get(entity, ()) + if type(attrs_to_prefetch) is set: + attrs_to_prefetch = frozenset(attrs_to_prefetch) + self.attrs_to_prefetch_dict[entity] = attrs_to_prefetch + return attrs_to_prefetch + def get_relations_to_prefetch(self, entity): + result = self.relations_to_prefetch_cache.get(entity) + if result is None: + attrs_to_prefetch = self.attrs_to_prefetch_dict[entity] + result = tuple(attr for attr in entity._attrs_ + if attr.is_relation and ( + attr in attrs_to_prefetch or + attr.py_type in self.entities_to_prefetch and not attr.is_collection)) + self.relations_to_prefetch_cache[entity] = result + return result + + class Local(localbase): def __init__(local): local.debug = False @@ -290,6 +326,7 @@ def __init__(local): local.db2cache = {} local.db_context_counter = 0 local.db_session = None + local.prefetch_context = None local.current_user = None local.perms_context = None local.user_groups_cache = {} @@ -2718,6 +2755,63 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if item._session_cache_ is not cache: throw(TransactionError, 'An attempt to mix objects belonging to different transactions') return items + def prefetch_load_all(attr, objects): + entity = attr.entity + database = entity._database_ + cache = database._get_cache() + if cache is None or not cache.is_alive: + throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over') + reverse = attr.reverse + rentity = reverse.entity + objects = sorted(objects, key=entity._get_raw_pkval_) + max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) + result = set() + if not reverse.is_collection: + for i in xrange(0, len(objects), max_batch_size): + batch = objects[i:i+max_batch_size] + sql, adapter, attr_offsets = 
rentity._construct_batchload_sql_(len(batch), reverse) + arguments = adapter(batch) + cursor = database._exec_sql(sql, arguments) + result.update(rentity._fetch_objects(cursor, attr_offsets)) + else: + pk_len = len(entity._pk_columns_) + m2m_dict = defaultdict(set) + for i in xrange(0, len(objects), max_batch_size): + batch = objects[i:i+max_batch_size] + sql, adapter = attr.construct_sql_m2m(len(batch)) + arguments = adapter(batch) + cursor = database._exec_sql(sql, arguments) + if len(batch) > 1: + for row in cursor.fetchall(): + obj = entity._get_by_raw_pkval_(row[:pk_len]) + item = rentity._get_by_raw_pkval_(row[pk_len:]) + m2m_dict[obj].add(item) + else: + obj = batch[0] + m2m_dict[obj] = {rentity._get_by_raw_pkval_(row) for row in cursor.fetchall()} + + for obj2, items in iteritems(m2m_dict): + setdata2 = obj2._vals_.get(attr) + if setdata2 is None: setdata2 = obj2._vals_[attr] = SetData() + else: + phantoms = setdata2 - items + if setdata2.added: phantoms -= setdata2.added + if phantoms: throw(UnrepeatableReadError, + 'Phantom object %s disappeared from collection %s.%s' + % (safe_repr(phantoms.pop()), safe_repr(obj2), attr.name)) + items -= setdata2 + if setdata2.removed: items -= setdata2.removed + setdata2 |= items + reverse.db_reverse_add(items, obj2) + result.update(items) + for obj in objects: + setdata = obj._vals_.get(attr) + if setdata is None: + setdata = obj._vals_[attr] = SetData() + setdata.is_fully_loaded = True + setdata.absent = None + setdata.count = len(setdata) + return result def load(attr, obj, items=None): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('load collection', obj, attr) @@ -2728,7 +2822,6 @@ def load(attr, obj, items=None): entity = attr.entity reverse = attr.reverse rentity = reverse.entity - if not reverse: throw(NotImplementedError) database = obj._database_ if cache is not database._get_cache(): throw(TransactionError, "Transaction of object %s belongs to different thread") 
@@ -3978,11 +4071,12 @@ def _find_by_sql_(entity, max_fetch_count, sql, globals, locals, frame_depth): objects = entity._fetch_objects(cursor, attr_offsets, max_fetch_count) return objects - def _construct_select_clause_(entity, alias=None, distinct=False, - query_attrs=(), attrs_to_prefetch=(), all_attributes=False): + def _construct_select_clause_(entity, alias=None, distinct=False, query_attrs=(), all_attributes=False): attr_offsets = {} select_list = [ 'DISTINCT' ] if distinct else [ 'ALL' ] root = entity._root_ + pc = local.prefetch_context + attrs_to_prefetch = pc.attrs_to_prefetch_dict.get(entity, ()) if pc else () for attr in chain(root._attrs_, root._subclass_attrs_): if not all_attributes and not issubclass(attr.entity, entity) \ and not issubclass(entity, attr.entity): continue @@ -4001,7 +4095,9 @@ def _construct_discriminator_criteria_(entity, alias=None): discr_values.append([ 'VALUE', entity._discriminator_]) return [ 'IN', [ 'COLUMN', alias, discr_attr.column ], discr_values ] def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True): - query_key = batch_size, attr, from_seeds + pc = local.prefetch_context + attrs_to_prefetch = pc.get_frozen_attrs_to_prefetch(entity) if pc is not None else () + query_key = batch_size, attr, from_seeds, attrs_to_prefetch cached_sql = entity._batchload_sql_cache_.get(query_key) if cached_sql is not None: return cached_sql select_list, attr_offsets = entity._construct_select_clause_(all_attributes=True) @@ -4504,6 +4600,20 @@ def __repr__(obj): if obj._pk_is_composite_: pkval = ','.join(imap(repr, pkval)) else: pkval = repr(pkval) return '%s[%s]' % (obj.__class__.__name__, pkval) + @classmethod + def _prefetch_load_all_(entity, objects): + objects = sorted(objects, key=entity._get_raw_pkval_) + database = entity._database_ + cache = database._get_cache() + if cache is None or not cache.is_alive: + throw(DatabaseSessionIsOver, 'Cannot load objects from the database: the database session is over') + 
max_batch_size = database.provider.max_params_count // len(entity._pk_columns_) + for i in xrange(0, len(objects), max_batch_size): + batch = objects[i:i+max_batch_size] + sql, adapter, attr_offsets = entity._construct_batchload_sql_(len(batch)) + arguments = adapter(batch) + cursor = database._exec_sql(sql, arguments) + entity._fetch_objects(cursor, attr_offsets) def _load_(obj): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('load object', obj) @@ -5438,8 +5548,7 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False query._for_update = query._nowait = False query._distinct = None query._prefetch = False - query._entities_to_prefetch = set() - query._attrs_to_prefetch_dict = defaultdict(set) + query._prefetch_context = PrefetchContext(query._database) def _get_type_(query): return QueryType(query) def _normalize_var(query, query_type): @@ -5477,8 +5586,9 @@ def _get_translator(query, query_key, vars): def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, aggr_func_name=None, aggr_func_distinct=None, sep=None): translator = query._translator expr_type = translator.expr_type - if isinstance(expr_type, EntityMeta) and query._attrs_to_prefetch_dict: - attrs_to_prefetch = tuple(sorted(query._attrs_to_prefetch_dict.get(expr_type, ()))) + attrs_to_prefetch_dict = query._prefetch_context.attrs_to_prefetch_dict + if isinstance(expr_type, EntityMeta) and attrs_to_prefetch_dict: + attrs_to_prefetch = tuple(sorted(attrs_to_prefetch_dict.get(expr_type, ()))) else: attrs_to_prefetch = () sql_key = HashableDict( @@ -5499,7 +5609,7 @@ def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, agg if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( limit, offset, query._distinct, aggr_func_name, aggr_func_distinct, sep, - query._for_update, query._nowait, attrs_to_prefetch) + query._for_update, query._nowait) cache = database._get_cache() 
sql, adapter = database.provider.ast2sql(sql_ast) cache_entry = sql, adapter, attr_offsets @@ -5518,115 +5628,110 @@ def get_sql(query): return sql def _actual_fetch(query, limit=None, offset=None): translator = query._translator - sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(limit, offset) - database = query._database - cache = database._get_cache() - if query._for_update: cache.immediate = True - cache.prepare_connection_for_query_execution() # may clear cache.query_results - items = cache.query_results.get(query_key) - if items is None: - cursor = database._exec_sql(sql, arguments) - if isinstance(translator.expr_type, EntityMeta): - entity = translator.expr_type - items = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, - used_attrs=translator.get_used_attrs()) - elif len(translator.row_layout) == 1: - func, slice_or_offset, src = translator.row_layout[0] - items = list(starmap(func, cursor.fetchall())) + with query._prefetch_context: + sql, arguments, attr_offsets, query_key = query._construct_sql_and_arguments(limit, offset) + database = query._database + cache = database._get_cache() + if query._for_update: cache.immediate = True + cache.prepare_connection_for_query_execution() # may clear cache.query_results + items = cache.query_results.get(query_key) + if items is None: + cursor = database._exec_sql(sql, arguments) + if isinstance(translator.expr_type, EntityMeta): + entity = translator.expr_type + items = entity._fetch_objects(cursor, attr_offsets, for_update=query._for_update, + used_attrs=translator.get_used_attrs()) + elif len(translator.row_layout) == 1: + func, slice_or_offset, src = translator.row_layout[0] + items = list(starmap(func, cursor.fetchall())) + else: + items = [ tuple(func(sql_row[slice_or_offset]) + for func, slice_or_offset, src in translator.row_layout) + for sql_row in cursor.fetchall() ] + for i, t in enumerate(translator.expr_type): + if isinstance(t, EntityMeta) and 
t._subclasses_: t._load_many_(row[i] for row in items) + if query_key is not None: cache.query_results[query_key] = items else: - items = [ tuple(func(sql_row[slice_or_offset]) - for func, slice_or_offset, src in translator.row_layout) - for sql_row in cursor.fetchall() ] - for i, t in enumerate(translator.expr_type): - if isinstance(t, EntityMeta) and t._subclasses_: t._load_many_(row[i] for row in items) - if query_key is not None: cache.query_results[query_key] = items - else: - stats = database._dblocal.stats - stat = stats.get(sql) - if stat is not None: stat.cache_count += 1 - else: stats[sql] = QueryStat(sql) - if query._prefetch: query._do_prefetch(items) + stats = database._dblocal.stats + stat = stats.get(sql) + if stat is not None: stat.cache_count += 1 + else: stats[sql] = QueryStat(sql) + if query._prefetch: query._do_prefetch(items) return items @cut_traceback def prefetch(query, *args): - query = query._clone(_entities_to_prefetch=query._entities_to_prefetch.copy(), - _attrs_to_prefetch_dict=query._attrs_to_prefetch_dict.copy()) + query = query._clone(_prefetch_context=query._prefetch_context.copy()) query._prefetch = True + prefetch_context = query._prefetch_context for arg in args: if isinstance(arg, EntityMeta): entity = arg if query._database is not entity._database_: throw(TypeError, 'Entity %s belongs to different database and cannot be prefetched' % entity.__name__) - query._entities_to_prefetch.add(entity) + prefetch_context.entities_to_prefetch.add(entity) elif isinstance(arg, Attribute): attr = arg entity = attr.entity if query._database is not entity._database_: throw(TypeError, 'Entity of attribute %s belongs to different database and cannot be prefetched' % attr) if isinstance(attr.py_type, EntityMeta) or attr.lazy: - query._attrs_to_prefetch_dict[entity].add(attr) + prefetch_context.attrs_to_prefetch_dict[entity].add(attr) else: throw(TypeError, 'Argument of prefetch() query method must be entity class or attribute. 
' 'Got: %r' % arg) return query - def _do_prefetch(query, result): + def _do_prefetch(query, query_result): expr_type = query._translator.expr_type - object_list = [] - object_set = set() - append_to_object_list = object_list.append - add_to_object_set = object_set.add + all_objects = set() + objects_to_process = set() + objects_to_prefetch = set() if isinstance(expr_type, EntityMeta): - for obj in result: - if obj not in object_set: - add_to_object_set(obj) - append_to_object_list(obj) + objects_to_process.update(query_result) + all_objects.update(query_result) elif type(expr_type) is tuple: - for i, t in enumerate(expr_type): - if not isinstance(t, EntityMeta): continue - for row in result: - obj = row[i] - if obj not in object_set: - add_to_object_set(obj) - append_to_object_list(obj) + obj_indexes = [ i for i, t in enumerate(expr_type) if isinstance(t, EntityMeta) ] + if obj_indexes: + for row in query_result: + objects_to_prefetch.update(row[i] for i in obj_indexes) + all_objects.update(objects_to_prefetch) + + prefetch_context = local.prefetch_context + assert prefetch_context + collection_prefetch_dict = defaultdict(set) + + objects_to_prefetch_dict = defaultdict(set) + while objects_to_process or objects_to_prefetch: + for obj in objects_to_process: + entity = obj.__class__ + relations_to_prefetch = prefetch_context.get_relations_to_prefetch(entity) + for attr in relations_to_prefetch: + if attr.is_collection: + collection_prefetch_dict[attr].add(obj) + else: + obj2 = attr.get(obj) + if obj2 not in all_objects: + all_objects.add(obj2) + objects_to_prefetch.add(obj2) + + next_objects_to_process = set() + for attr, objects in collection_prefetch_dict.items(): + items = attr.prefetch_load_all(objects) + if attr.reverse.is_collection: + objects_to_prefetch.update(items) + else: + next_objects_to_process.update(item for item in items if item not in all_objects) + collection_prefetch_dict.clear() - cache = query._database._get_cache() - entities_to_prefetch = 
query._entities_to_prefetch - attrs_to_prefetch_dict = query._attrs_to_prefetch_dict - prefetching_attrs_cache = {} - for obj in object_list: - entity = obj.__class__ - if obj in cache.seeds[entity._pk_attrs_]: obj._load_() + for obj in objects_to_prefetch: + objects_to_prefetch_dict[obj.__class__._root_].add(obj) + objects_to_prefetch.clear() - all_attrs_to_prefetch = prefetching_attrs_cache.get(entity) - if all_attrs_to_prefetch is None: - all_attrs_to_prefetch = [] - append = all_attrs_to_prefetch.append - attrs_to_prefetch = attrs_to_prefetch_dict[entity] - for attr in obj._attrs_: - if attr.is_collection: - if attr in attrs_to_prefetch: append(attr) - elif attr.is_relation: - if attr in attrs_to_prefetch or attr.py_type in entities_to_prefetch: append(attr) - elif attr.lazy: - if attr in attrs_to_prefetch: append(attr) - prefetching_attrs_cache[entity] = all_attrs_to_prefetch + for entity, objects in objects_to_prefetch_dict.items(): + next_objects_to_process.update(objects) + entity._prefetch_load_all_(objects) + objects_to_prefetch_dict.clear() - for attr in all_attrs_to_prefetch: - if attr.is_collection: - if not isinstance(attr, Set): throw(NotImplementedError) - setdata = obj._vals_.get(attr) - if setdata is None or not setdata.is_fully_loaded: setdata = attr.load(obj) - for obj2 in setdata: - if obj2 not in object_set: - add_to_object_set(obj2) - append_to_object_list(obj2) - elif attr.is_relation: - obj2 = attr.get(obj) - if obj2 is not None and obj2 not in object_set: - add_to_object_set(obj2) - append_to_object_list(obj2) - elif attr.lazy: attr.get(obj) - else: assert False # pragma: no cover + objects_to_process = next_objects_to_process @cut_traceback def show(query, width=None): query._fetch().show(width) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 59dc68d7a..14f05e9f6 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -651,7 +651,7 @@ def construct_subquery_ast(translator, limit=None, 
offset=None, aliases=None, st return [ 'SELECT', select_ast, from_ast, where_ast ] + other_ast def construct_sql_ast(translator, limit=None, offset=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, - for_update=False, nowait=False, attrs_to_prefetch=(), is_not_null_checks=False): + for_update=False, nowait=False, is_not_null_checks=False): attr_offsets = None if distinct is None: distinct = translator.distinct ast_transformer = lambda ast: ast @@ -720,7 +720,7 @@ def ast_transformer(ast): elif isinstance(translator.expr_type, EntityMeta) and not translator.parent \ and not translator.aggregated and not translator.optimize: select_ast, attr_offsets = translator.expr_type._construct_select_clause_( - translator.alias, distinct, translator.tableref.used_attrs, attrs_to_prefetch) + translator.alias, distinct, translator.tableref.used_attrs) sql_ast.append(select_ast) sql_ast.append(translator.sqlquery.from_ast) diff --git a/pony/orm/tests/test_prefetching.py b/pony/orm/tests/test_prefetching.py index ff3fda31f..f5f4d40e3 100644 --- a/pony/orm/tests/test_prefetching.py +++ b/pony/orm/tests/test_prefetching.py @@ -18,7 +18,7 @@ class Student(db.Entity): class Group(db.Entity): number = PrimaryKey(int) - major = Required(str) + major = Required(str, lazy=True) students = Set(Student) class Course(db.Entity): @@ -33,9 +33,11 @@ class Course(db.Entity): c1 = Course(name='Math') c2 = Course(name='Physics') c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio') + Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio') + Student(id=3, name='S3', group=g1, gpa=4.7, 
scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3]) + Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3]) class TestPrefetching(unittest.TestCase): def test_1(self): @@ -53,13 +55,13 @@ def test_2(self): def test_3(self): with db_session: - s1 = Student.select().prefetch(Group).first() + s1 = Student.select().prefetch(Group, Group.major).first() g = s1.group self.assertEqual(g.major, 'Math') def test_4(self): with db_session: - s1 = Student.select().prefetch(Student.group).first() + s1 = Student.select().prefetch(Student.group, Group.major).first() g = s1.group self.assertEqual(g.major, 'Math') @@ -76,7 +78,7 @@ def test_6(self): def test_7(self): with db_session: - name, group = select((s.name, s.group) for s in Student).prefetch(Group).first() + name, group = select((s.name, s.group) for s in Student).prefetch(Group, Group.major).first() self.assertEqual(group.major, 'Math') @raises_exception(DatabaseSessionIsOver, 'Cannot load collection Student[1].courses: the database session is over') @@ -105,11 +107,67 @@ def test_11(self): def test_12(self): with db_session: s1 = Student.select().prefetch(Student.biography).first() - self.assertEqual(s1.biography, 'some text') + self.assertEqual(s1.biography, 'S1 bio') self.assertEqual(db.last_sql, '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."biography" FROM "Student" "s" ORDER BY 1 LIMIT 1''') + def test_13(self): + db.merge_local_stats() + with db_session: + q = select(g for g in Group) + for g in q: # 1 query + for s in g.students: # 2 query + b = s.biography # 5 queries + query_count = sum(stat.db_count for stat in db.local_stats.values()) + self.assertEqual(query_count, 8) + + def test_14(self): + db.merge_local_stats() + with db_session: + q = select(g for g in Group).prefetch(Group.students) + for g in q: # 1 query + for s in g.students: # 1 query + b = 
s.biography # 5 queries + query_count = sum(stat.db_count for stat in db.local_stats.values()) + self.assertEqual(query_count, 7) + + def test_15(self): + with db_session: + q = select(g for g in Group).prefetch(Group.students) + q[:] + db.merge_local_stats() + with db_session: + q = select(g for g in Group).prefetch(Group.students, Student.biography) + for g in q: # 1 query + for s in g.students: # 1 query + b = s.biography # 0 queries + query_count = sum(stat.db_count for stat in db.local_stats.values()) + self.assertEqual(query_count, 2) + + def test_16(self): + db.merge_local_stats() + with db_session: + q = select(c for c in Course).prefetch(Course.students, Student.biography) + for c in q: # 1 query + for s in c.students: # 2 queries (as it is many-to-many relationship) + b = s.biography # 0 queries + query_count = sum(stat.db_count for stat in db.local_stats.values()) + self.assertEqual(query_count, 3) + + def test_17(self): + db.merge_local_stats() + with db_session: + q = select(c for c in Course).prefetch(Course.students, Student.biography, Group, Group.major) + for c in q: # 1 query + for s in c.students: # 2 queries (as it is many-to-many relationship) + m = s.group.major # 1 query + b = s.biography # 0 queries + query_count = sum(stat.db_count for stat in db.local_stats.values()) + self.assertEqual(query_count, 4) + + + if __name__ == '__main__': unittest.main() From e78d98ed632f1439b91988f9cad0c80ce7bd7b8f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 6 Nov 2018 15:37:44 +0300 Subject: [PATCH 402/547] Increase DBAPIProvider.max_params_count value --- pony/orm/dbapiprovider.py | 2 +- pony/orm/dbproviders/mysql.py | 1 + pony/orm/dbproviders/postgres.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index aadd20532..816dfd6ad 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -85,7 +85,7 @@ def get_version_tuple(s): class 
DBAPIProvider(object): paramstyle = 'qmark' quote_char = '"' - max_params_count = 200 + max_params_count = 999 max_name_len = 128 table_if_not_exists_syntax = True index_if_not_exists_syntax = True diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 3ea432619..d4da64af3 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -183,6 +183,7 @@ class MySQLProvider(DBAPIProvider): paramstyle = 'format' quote_char = "`" max_name_len = 64 + max_params_count = 10000 table_if_not_exists_syntax = True index_if_not_exists_syntax = False select_for_update_nowait_syntax = False diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index c93d801cc..96dbfece1 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -202,6 +202,7 @@ class PGProvider(DBAPIProvider): dialect = 'PostgreSQL' paramstyle = 'pyformat' max_name_len = 63 + max_params_count = 10000 index_if_not_exists_syntax = False dbapi_module = psycopg2 From 57c862882c1a1e2a24bcf94c85a5b826ece4548f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 7 Nov 2018 13:26:47 +0300 Subject: [PATCH 403/547] Local variable renaming --- pony/orm/core.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 69d73120b..52fb0bbef 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2174,14 +2174,14 @@ def parse_value(attr, row, offsets): if not attr.reverse: if len(offsets) > 1: throw(NotImplementedError) offset = offsets[0] - val = attr.validate(row[offset], None, attr.entity, from_db=True) + dbval = attr.validate(row[offset], None, attr.entity, from_db=True) else: - vals = [ row[offset] for offset in offsets ] - if None in vals: - assert len(set(vals)) == 1 - val = None - else: val = attr.py_type._get_by_raw_pkval_(vals) - return val + dbvals = [ row[offset] for offset in offsets ] + if None in dbvals: + assert len(set(dbvals)) == 1 + 
dbval = None + else: dbval = attr.py_type._get_by_raw_pkval_(dbvals) + return dbval def load(attr, obj): cache = obj._session_cache_ if cache is None or not cache.is_alive: throw_db_session_is_over('load attribute', obj, attr) From d1f5e795a6c536c3d8bc884d6ea0d10dca2524ac Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 7 Nov 2018 17:02:59 +0300 Subject: [PATCH 404/547] Memory optimization: deduplication of values received from the database in the same session --- pony/orm/core.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 52fb0bbef..b65362661 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1708,6 +1708,7 @@ def __init__(cache, database): cache.objects_to_save = [] cache.saved_objects = [] cache.query_results = {} + cache.dbvals_deduplication_cache = {} cache.modified = False cache.db_session = db_session = local.db_session cache.immediate = db_session is not None and db_session.immediate @@ -1825,7 +1826,7 @@ def close(cache, rollback=True): cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ - = cache.modified_collections = cache.collection_statistics = None + = cache.modified_collections = cache.collection_statistics = cache.dbvals_deduplication_cache = None @contextmanager def flush_disabled(cache): cache.noflush_counter += 1 @@ -2169,12 +2170,14 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if attr.py_check is not None and not attr.py_check(val): throw(ValueError, 'Check for attribute %s failed. 
Value: %s' % (attr, truncate_repr(val))) return val - def parse_value(attr, row, offsets): + def parse_value(attr, row, offsets, dbvals_deduplication_cache): assert len(attr.columns) == len(offsets) if not attr.reverse: if len(offsets) > 1: throw(NotImplementedError) offset = offsets[0] dbval = attr.validate(row[offset], None, attr.entity, from_db=True) + try: dbval = dbvals_deduplication_cache.setdefault(dbval, dbval) + except: pass else: dbvals = [ row[offset] for offset in offsets ] if None in dbvals: @@ -2211,7 +2214,7 @@ def load(attr, obj): arguments = adapter(obj._get_raw_pkval_()) cursor = database._exec_sql(sql, arguments) row = cursor.fetchone() - dbval = attr.parse_value(row, offsets) + dbval = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) attr.db_set(obj, dbval) else: obj._load_() return obj._vals_[attr] @@ -4203,11 +4206,14 @@ def _parse_row_(entity, row, attr_offsets): real_entity_subclass = discr_attr.code2cls[discr_value] discr_value = real_entity_subclass._discriminator_ # To convert unicode to str in Python 2.x + database = entity._database_ + cache = local.db2cache[database] + avdict = {} for attr in real_entity_subclass._attrs_: offsets = attr_offsets.get(attr) if offsets is None or attr.is_discriminator: continue - avdict[attr] = attr.parse_value(row, offsets) + avdict[attr] = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) pkval = tuple(avdict.pop(attr, discr_value) for attr in entity._pk_attrs_) assert None not in pkval From 1cd129cef7592321a6d037c43ea86d38d1152f2a Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 21 Nov 2018 13:12:21 +0300 Subject: [PATCH 405/547] Fixes #404: Google App Engine local run detection --- pony/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/__init__.py b/pony/__init__.py index de1834556..18d6d1be5 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, print_function 
-import sys, time, threading, random +import os, sys, time, threading, random from os.path import dirname from itertools import count @@ -12,9 +12,9 @@ def detect_mode(): try: import google.appengine except ImportError: pass else: - try: import dev_appserver - except ImportError: return 'GAE-SERVER' - return 'GAE-LOCAL' + if os.environ.get('SERVER_SOFTWARE', '').startswith('Development'): + return 'GAE-LOCAL' + return 'GAE-SERVER' try: from mod_wsgi import version except: pass From db964c297971c8f10d56fa8df20a9312222f6500 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 12 Nov 2018 13:42:18 +0300 Subject: [PATCH 406/547] Support for "SELECT ... FOR UPDATE SKIP LOCKED" added --- pony/orm/core.py | 38 +++++++++++++++++++--------------- pony/orm/dbapiprovider.py | 1 - pony/orm/dbproviders/mysql.py | 1 - pony/orm/dbproviders/oracle.py | 8 ++++--- pony/orm/dbproviders/sqlite.py | 5 ++--- pony/orm/sqlbuilding.py | 6 ++++-- pony/orm/sqltranslation.py | 4 ++-- 7 files changed, 34 insertions(+), 29 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b65362661..e79cdd150 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3892,8 +3892,12 @@ def get(entity, *args, **kwargs): @cut_traceback def get_for_update(entity, *args, **kwargs): nowait = kwargs.pop('nowait', False) - if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1).for_update(nowait).get() - try: return entity._find_one_(kwargs, True, nowait) # can throw MultipleObjectsFoundError + skip_locked = kwargs.pop('skip_locked', False) + if nowait and skip_locked: + throw(TypeError, 'nowait and skip_locked options are mutually exclusive') + if args: return entity._query_from_args_(args, kwargs, frame_depth=cut_traceback_depth+1) \ + .for_update(nowait, skip_locked).get() + try: return entity._find_one_(kwargs, True, nowait, skip_locked) # can throw MultipleObjectsFoundError except ObjectNotFound: return None @cut_traceback def get_by_sql(entity, sql, 
globals=None, locals=None): @@ -3969,7 +3973,7 @@ def select_random(entity, limit): if obj in seeds: obj._load_() if found_in_cache: shuffle(result) return result - def _find_one_(entity, kwargs, for_update=False, nowait=False): + def _find_one_(entity, kwargs, for_update=False, nowait=False, skip_locked=False): if entity._database_.schema is None: throw(ERDiagramError, 'Mapping is not generated for entity %r' % entity.__name__) avdict = {} @@ -3986,7 +3990,7 @@ def _find_one_(entity, kwargs, for_update=False, nowait=False): if attr.is_collection: throw(TypeError, 'Collection attribute %s cannot be specified as search criteria' % attr) obj, unique = entity._find_in_cache_(pkval, avdict, for_update) - if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait) + if obj is None: obj = entity._find_in_db_(avdict, unique, for_update, nowait, skip_locked) if obj is None: throw(ObjectNotFound, entity, pkval) return obj def _find_in_cache_(entity, pkval, avdict, for_update=False): @@ -4038,11 +4042,11 @@ def _find_in_cache_(entity, pkval, avdict, for_update=False): entity._set_rbits((obj,), avdict) return obj, unique return None, unique - def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False): + def _find_in_db_(entity, avdict, unique=False, for_update=False, nowait=False, skip_locked=False): database = entity._database_ query_attrs = {attr: value is None for attr, value in iteritems(avdict)} limit = 2 if not unique else None - sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait) + sql, adapter, attr_offsets = entity._construct_sql_(query_attrs, False, limit, for_update, nowait, skip_locked) arguments = adapter(avdict) if for_update: database._get_cache().immediate = True cursor = database._exec_sql(sql, arguments) @@ -4120,10 +4124,10 @@ def _construct_batchload_sql_(entity, batch_size, attr=None, from_seeds=True): cached_sql = sql, adapter, attr_offsets 
entity._batchload_sql_cache_[query_key] = cached_sql return cached_sql - def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False): - if nowait: assert for_update + def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_update=False, nowait=False, skip_locked=False): + if nowait or skip_locked: assert for_update sorted_query_attrs = tuple(sorted(query_attrs.items())) - query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait + query_key = sorted_query_attrs, order_by_pk, limit, for_update, nowait, skip_locked cached_sql = entity._find_sql_cache_.get(query_key) if cached_sql is not None: return cached_sql select_list, attr_offsets = entity._construct_select_clause_(query_attrs=query_attrs) @@ -4153,7 +4157,7 @@ def _construct_sql_(entity, query_attrs, order_by_pk=False, limit=None, for_upda where_list.append([ converter.EQ, [ 'COLUMN', None, column ], [ 'PARAM', (attr, None, j), converter ] ]) if not for_update: sql_ast = [ 'SELECT', select_list, from_list, where_list ] - else: sql_ast = [ 'SELECT_FOR_UPDATE', bool(nowait), select_list, from_list, where_list ] + else: sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked, select_list, from_list, where_list ] if order_by_pk: sql_ast.append([ 'ORDER_BY' ] + [ [ 'COLUMN', None, column ] for column in entity._pk_columns_ ]) if limit is not None: sql_ast.append([ 'LIMIT', limit ]) database = entity._database_ @@ -5551,7 +5555,7 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False query._translator = translator query._filters = () query._next_kwarg_id = 0 - query._for_update = query._nowait = False + query._for_update = query._nowait = query._skip_locked = False query._distinct = None query._prefetch = False query._prefetch_context = PrefetchContext(query._database) @@ -5607,6 +5611,7 @@ def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, agg aggr_func=(aggr_func_name, aggr_func_distinct, 
sep), for_update=query._for_update, nowait=query._nowait, + skip_locked=query._skip_locked, inner_join_syntax=options.INNER_JOIN_SYNTAX, attrs_to_prefetch=attrs_to_prefetch ) @@ -5615,7 +5620,7 @@ def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, agg if cache_entry is None: sql_ast, attr_offsets = translator.construct_sql_ast( limit, offset, query._distinct, aggr_func_name, aggr_func_distinct, sep, - query._for_update, query._nowait) + query._for_update, query._nowait, query._skip_locked) cache = database._get_cache() sql, adapter = database.provider.ast2sql(sql_ast) cache_entry = sql, adapter, attr_offsets @@ -6060,11 +6065,10 @@ def max(query): def count(query, distinct=None): return query._aggregate('COUNT', distinct) @cut_traceback - def for_update(query, nowait=False): - provider = query._database.provider - if nowait and not provider.select_for_update_nowait_syntax: throw(TranslationError, - '%s provider does not support SELECT FOR UPDATE NOWAIT syntax' % provider.dialect) - return query._clone(_for_update=True, _nowait=nowait) + def for_update(query, nowait=False, skip_locked=False): + if nowait and skip_locked: + throw(TypeError, 'nowait and skip_locked options are mutually exclusive') + return query._clone(_for_update=True, _nowait=nowait, _skip_locked=skip_locked) def random(query, limit): return query.order_by('random()')[:limit] def to_json(query, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 816dfd6ad..fa303a638 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -91,7 +91,6 @@ class DBAPIProvider(object): index_if_not_exists_syntax = True max_time_precision = default_time_precision = 6 uint64_support = False - select_for_update_nowait_syntax = True # SQLite and PostgreSQL does not limit varchar max length. 
varchar_default_max_len = None diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index d4da64af3..ae8eecb37 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -186,7 +186,6 @@ class MySQLProvider(DBAPIProvider): max_params_count = 10000 table_if_not_exists_syntax = True index_if_not_exists_syntax = False - select_for_update_nowait_syntax = False max_time_precision = default_time_precision = 0 varchar_default_max_len = 255 uint64_support = True diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 50a6df949..c52cd09ff 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -131,11 +131,13 @@ def INSERT(builder, table_name, columns, values, returning=None): if returning is not None: result.extend((' RETURNING ', builder.quote_name(returning), ' INTO :new_id')) return result - def SELECT_FOR_UPDATE(builder, nowait, *sections): + def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): assert not builder.indent + nowait = ' NOWAIT' if nowait else '' + skip_locked = ' SKIP LOCKED' if skip_locked else '' last_section = sections[-1] if last_section[0] != 'LIMIT': - return builder.SELECT(*sections), 'FOR UPDATE NOWAIT\n' if nowait else 'FOR UPDATE\n' + return builder.SELECT(*sections), 'FOR UPDATE', nowait, skip_locked, '\n' from_section = sections[1] assert from_section[0] == 'FROM' @@ -154,7 +156,7 @@ def SELECT_FOR_UPDATE(builder, nowait, *sections): ('SELECT', [ 'ROWID', ['AS', rowid, 'row-id' ] ]) + sections[1:] ] ] ] if order_by_section: sql_ast.append(order_by_section) result = builder(sql_ast) - return result, 'FOR UPDATE NOWAIT\n' if nowait else 'FOR UPDATE\n' + return result, 'FOR UPDATE', nowait, skip_locked, '\n' def SELECT(builder, *sections): prev_suppress_aliases = builder.suppress_aliases builder.suppress_aliases = False diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 207b3fb7c..5f3e953c1 100644 --- 
a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -61,8 +61,8 @@ class SQLiteBuilder(SQLBuilder): def __init__(builder, provider, ast): builder.json1_available = provider.json1_available SQLBuilder.__init__(builder, provider, ast) - def SELECT_FOR_UPDATE(builder, nowait, *sections): - assert not builder.indent and not nowait + def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): + assert not builder.indent return builder.SELECT(*sections) def INSERT(builder, table_name, columns, values, returning=None): if not values: return 'INSERT INTO %s DEFAULT VALUES' % builder.quote_name(table_name) @@ -266,7 +266,6 @@ class SQLiteProvider(DBAPIProvider): dialect = 'SQLite' local_exceptions = local_exceptions max_name_len = 1024 - select_for_update_nowait_syntax = False dbapi_module = sqlite dbschema_cls = SQLiteSchema diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index ac61689da..9c4d1d4b8 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -254,10 +254,12 @@ def SELECT(builder, *sections): return result finally: builder.suppress_aliases = prev_suppress_aliases - def SELECT_FOR_UPDATE(builder, nowait, *sections): + def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): assert not builder.indent result = builder.SELECT(*sections) - return result, 'FOR UPDATE NOWAIT\n' if nowait else 'FOR UPDATE\n' + nowait = ' NOWAIT' if nowait else '' + skip_locked = ' SKIP LOCKED' if skip_locked else '' + return result, 'FOR UPDATE', nowait, skip_locked, '\n' def EXISTS(builder, *sections): result = builder._subquery(*sections) indent = builder.indent_spaces * builder.indent diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 14f05e9f6..00a6a1aaa 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -651,12 +651,12 @@ def construct_subquery_ast(translator, limit=None, offset=None, aliases=None, st return [ 'SELECT', select_ast, from_ast, where_ast ] + other_ast def 
construct_sql_ast(translator, limit=None, offset=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, - for_update=False, nowait=False, is_not_null_checks=False): + for_update=False, nowait=False, skip_locked=False, is_not_null_checks=False): attr_offsets = None if distinct is None: distinct = translator.distinct ast_transformer = lambda ast: ast if for_update: - sql_ast = [ 'SELECT_FOR_UPDATE', nowait ] + sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked ] translator.query_result_is_cacheable = False else: sql_ast = [ 'SELECT' ] From a832d37f3e57a20bbe0de7be6a8b288f8c8cf0be Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 17 Nov 2018 16:37:30 +0300 Subject: [PATCH 407/547] Code formatting --- pony/orm/ormtypes.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 5fc576732..be256873b 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -206,12 +206,12 @@ def normalize_type(t): throw(TypeError, 'Unsupported type %r' % t.__name__) coercions = { - (int, float) : float, - (int, Decimal) : Decimal, - (date, datetime) : datetime, - (bool, int) : int, - (bool, float) : float, - (bool, Decimal) : Decimal + (int, float): float, + (int, Decimal): Decimal, + (date, datetime): datetime, + (bool, int): int, + (bool, float): float, + (bool, Decimal): Decimal } coercions.update(((t2, t1), t3) for ((t1, t2), t3) in items_list(coercions)) From 873470698f8cf8a9ade8bc9eda832869d82a1c27 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 24 Nov 2018 16:30:23 +0300 Subject: [PATCH 408/547] Fix duplicated table join in FROM clause of optimized queries --- pony/orm/sqltranslation.py | 20 +++++++++++++------- pony/orm/tests/queries.txt | 16 ++++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 00a6a1aaa..1116f418c 100644 --- a/pony/orm/sqltranslation.py +++ 
b/pony/orm/sqltranslation.py @@ -238,14 +238,14 @@ def init(translator, tree, parent_translator, code_key=None, filter_num=None, ex if parent_translator is None: translator.root_translator = translator translator.database = None - translator.sqlquery = SqlQuery(left_join=left_join) + translator.sqlquery = SqlQuery(translator, left_join=left_join) assert code_key is not None and filter_num is not None translator.code_key = translator.original_code_key = code_key translator.filter_num = translator.original_filter_num = filter_num else: translator.root_translator = parent_translator.root_translator translator.database = parent_translator.database - translator.sqlquery = SqlQuery(parent_translator.sqlquery, left_join=left_join) + translator.sqlquery = SqlQuery(translator, parent_translator.sqlquery, left_join=left_join) assert code_key is None and filter_num is None translator.code_key = parent_translator.code_key translator.filter_num = parent_translator.filter_num @@ -1180,7 +1180,8 @@ def coerce_monads(m1, m2, for_comparison=False): max_alias_length = 30 class SqlQuery(object): - def __init__(sqlquery, parent_sqlquery=None, left_join=False): + def __init__(sqlquery, translator, parent_sqlquery=None, left_join=False): + sqlquery.translator = translator sqlquery.parent_sqlquery = parent_sqlquery sqlquery.left_join = left_join sqlquery.from_ast = [ 'LEFT_JOIN' if left_join else 'FROM' ] @@ -1382,7 +1383,12 @@ def make_join(tableref, pk_only=False): discr_criteria = entity._construct_discriminator_criteria_(alias) assert discr_criteria is not None join_cond.append(discr_criteria) - sqlquery.join_table(parent_alias, alias, entity._table_, join_cond) + + translator = tableref.sqlquery.translator.root_translator + if translator.optimize == tableref.name_path and translator.from_optimized and tableref.sqlquery is translator.sqlquery: + pass + else: + sqlquery.join_table(parent_alias, alias, entity._table_, join_cond) tableref.alias = alias tableref.pk_columns = pk_columns 
tableref.optimized = False @@ -2937,7 +2943,7 @@ def contains(monad, item, not_in=False): expr_ast = sqland([ [ 'EQ', expr1, expr2 ] for expr1, expr2 in izip(expr_list, item.getsql()) ]) return BoolExprMonad(expr_ast, nullable=False) else: - sqlquery = SqlQuery(translator.sqlquery) + sqlquery = SqlQuery(translator, translator.sqlquery) tableref = monad.make_tableref(sqlquery) attr = monad.attr alias, columns = tableref.make_join(pk_only=attr.reverse) @@ -3190,7 +3196,7 @@ def _subselect(monad, sqlquery=None, extract_outer_conditions=True): attr = monad.attr translator = monad.translator if sqlquery is None: - sqlquery = SqlQuery(translator.sqlquery) + sqlquery = SqlQuery(translator, translator.sqlquery) monad.make_tableref(sqlquery) sqlquery.expr_list = monad.make_expr_list() if not attr.reverse and not attr.is_required: @@ -3230,7 +3236,7 @@ def __init__(monad, op, sqlop, left, right): def aggregate(monad, func_name, distinct=None, sep=None): distinct = distinct_from_monad(distinct, default=monad.forced_distinct and func_name in ('SUM', 'AVG')) translator = monad.translator - sqlquery = SqlQuery(translator.sqlquery) + sqlquery = SqlQuery(translator, translator.sqlquery) expr = monad.getsql(sqlquery)[0] translator.aggregated_subquery_paths.add(monad.tableref.name_path) outer_cond = sqlquery.from_ast[1].pop() diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 29c4a1e17..2ce040941 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -146,6 +146,22 @@ SELECT "s"."group", MIN("s"."gpa"), MAX("s"."gpa") FROM "Student" "s" GROUP BY "s"."group" +>>> select((g, min(g.students.gpa), max(g.students.gpa)) for g in Group) + +SELECT "g"."number", MIN("student"."gpa"), MAX("student"."gpa") +FROM "Group" "g" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" +GROUP BY "g"."number" + +>>> select((g, g.students.name, min(g.students.gpa), max(g.students.gpa)) for g in Group) + +SELECT "g"."number", "student"."name", 
MIN("student"."gpa"), MAX("student"."gpa") +FROM "Group" "g" + LEFT JOIN "Student" "student" + ON "g"."number" = "student"."group" +GROUP BY "g"."number", "student"."name" + >>> count(s for s in Student if s.group.number == 101) SELECT COUNT(*) From 16ab5ff495df2ac8febd5a1e3e5f93dd8a5d478f Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 7 Nov 2018 19:30:58 +0300 Subject: [PATCH 409/547] TrackedArray type added --- pony/orm/dbapiprovider.py | 6 +++--- pony/orm/dbproviders/sqlite.py | 4 ++-- pony/orm/ormtypes.py | 26 +++++++++++++++++++++++++- 3 files changed, 30 insertions(+), 6 deletions(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index fa303a638..424c273d4 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -9,7 +9,7 @@ import pony from pony.utils import is_utf8, decorator, throw, localbase, deprecated from pony.converting import str2date, str2time, str2datetime, str2timedelta -from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedList, Json, QueryType, Array +from pony.orm.ormtypes import LongStr, LongUnicode, RawSQLType, TrackedValue, TrackedArray, Json, QueryType, Array class DBException(Exception): def __init__(exc, original_exc, *args): @@ -841,12 +841,12 @@ def validate(converter, val, obj=None): throw(TypeError, 'Cannot store %s item in array of %s' % (type(v).__name__, converter.py_type.item_type.__name__)) - return TrackedList(obj, converter.attr, val) + return TrackedArray(obj, converter.attr, val) def dbval2val(converter, dbval, obj=None): if obj is None: return dbval - return TrackedList(obj, converter.attr, dbval) + return TrackedArray(obj, converter.attr, dbval) def val2dbval(converter, val, obj=None): return list(val) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 5f3e953c1..6bb224a3b 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -14,7 +14,7 @@ from pony.orm import core, dbschema, 
dbapiprovider from pony.orm.core import log_orm -from pony.orm.ormtypes import Json, TrackedList +from pony.orm.ormtypes import Json, TrackedArray from pony.orm.sqltranslation import SQLTranslator, StringExprMonad from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions @@ -234,7 +234,7 @@ def dbval2val(converter, dbval, obj=None): items = json.loads(dbval) if dbval else [] if obj is None: return items - return TrackedList(obj, converter.attr, items) + return TrackedArray(obj, converter.attr, items) def val2dbval(converter, val, obj=None): return dumps(val) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index be256873b..69f880a95 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -340,6 +340,30 @@ def __reduce__(self): def get_untracked(self): return [val.get_untracked() if isinstance(val, TrackedValue) else val for val in self] +def validate_item(item_type, item): + if not isinstance(item, item_type): + if item_type is not unicode and hasattr(item, '__index__'): + return item.__index__() + throw(TypeError, 'Cannot store %r item in array of %r' % (type(item).__name__, item_type.__name__)) + return item + +class TrackedArray(TrackedList): + def __init__(self, obj, attr, value): + TrackedList.__init__(self, obj, attr, value) + self.item_type = attr.py_type.item_type + def extend(self, items): + items = [validate_item(self.item_type, item) for item in items] + TrackedList.extend(self, items) + def append(self, item): + item = validate_item(self.item_type, item) + TrackedList.append(self, item) + def insert(self, index, item): + item = validate_item(self.item_type, item) + TrackedList.insert(self, index, item) + def __setitem__(self, index, item): + item = validate_item(self.item_type, item) + TrackedList.__setitem__(self, index, item) + class Json(object): """A wrapper over a dict or list """ @@ -351,7 +375,7 @@ def __repr__(self): class Array(object): def 
__init__(self, item_type): - if item_type not in(unicode, int, float): + if item_type not in (unicode, int, float): throw(NotImplementedError, 'Only int, float and str types are supported. Got: `Array(%r)`' % item_type) self.item_type = item_type From ab257bbd14deb74f5c7c579ba8820bbda039ca4b Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Mon, 12 Nov 2018 02:37:20 +0300 Subject: [PATCH 410/547] Array tests added --- pony/orm/tests/test_array.py | 76 ++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) create mode 100644 pony/orm/tests/test_array.py diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py new file mode 100644 index 000000000..976fea628 --- /dev/null +++ b/pony/orm/tests/test_array.py @@ -0,0 +1,76 @@ +from pony.py23compat import PY2 + +import unittest +from pony.orm.tests.testutils import * + +from pony.orm import * + +db = Database('sqlite', ':memory:') + +class Foo(db.Entity): + array1 = Required(IntArray, index=True) + array2 = Required(FloatArray) + array3 = Required(StrArray) + +db.generate_mapping(create_tables=True) + + +with db_session: + Foo(array1=[1, 2, 3, 4, 5], array2=[1.1, 2.2, 3.3, 4.4, 5.5], array3=['foo', 'bar']) + +class Test(unittest.TestCase): + @db_session + def test_1(self): + foo = select(f for f in Foo if 1 in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_2(self): + foo = select(f for f in Foo if [1, 2, 5] in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_3(self): + x = [1, 2, 5] + foo = select(f for f in Foo if x in f.array1)[:] + self.assertEqual([Foo[1]], foo) + + @db_session + def test_4(self): + foo = select(f for f in Foo if 1.1 in f.array2)[:] + self.assertEqual([Foo[1]], foo) + + err_msg = "Cannot store 'int' item in array of " + ("'unicode'" if PY2 else "'str'") + + @raises_exception(TypeError, err_msg) + @db_session + def test_5(self): + foo = Foo.select().first() + foo.array3.append(123) + + @raises_exception(TypeError, 
err_msg) + @db_session + def test_6(self): + foo = Foo.select().first() + foo.array3[0] = 123 + + @raises_exception(TypeError, err_msg) + @db_session + def test_7(self): + foo = Foo.select().first() + foo.array3.extend(['str', 123, 'str']) + + @db_session + def test_8(self): + foo = Foo.select().first() + foo.array3.extend(['str1', 'str2']) + + @db_session + def test_9(self): + foos = select(f.array2[0] for f in Foo)[:] + self.assertEqual([1.1], foos) + + @db_session + def test_10(self): + foos = select(f.array1[1:-1] for f in Foo)[:] + self.assertEqual([2, 3, 4], foos[0]) From 7e21d5da3c4a5546df4ee855d8ec0693b0089fbf Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 17 Nov 2018 15:51:09 +0300 Subject: [PATCH 411/547] Make IntArray, StrArray and FloatArray classes --- pony/orm/dbapiprovider.py | 26 ++++++----- pony/orm/ormtypes.py | 30 +++++-------- pony/orm/sqltranslation.py | 90 +++++++++++++++++++------------------- 3 files changed, 69 insertions(+), 77 deletions(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 424c273d4..378f3bff4 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -270,11 +270,11 @@ def _get_converter_type_by_py_type(provider, py_type): if isinstance(py_type, type): for t, converter_cls in provider.converter_classes: if issubclass(py_type, t): return converter_cls - if isinstance(py_type, Array): - converter_cls = provider.array_converter_cls - if converter_cls is None: - throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect) - return converter_cls + if issubclass(py_type, Array): + converter_cls = provider.array_converter_cls + if converter_cls is None: + throw(NotImplementedError, 'Array type is not supported for %r' % provider.dialect) + return converter_cls if isinstance(py_type, RawSQLType): return Converter # for cases like select(raw_sql(...) 
for x in X) throw(TypeError, 'No database converter found for type %s' % py_type) @@ -821,27 +821,29 @@ def __init__(converter, provider, py_type, attr=None): converter.item_converter = converter.array_types[converter.py_type.item_type][1] def validate(converter, val, obj=None): - if obj is None or converter.attr is None: - return val if isinstance(val, TrackedValue) and val.obj_ref() is obj and val.attr is converter.attr: return val if isinstance(val, basestring) or not hasattr(val, '__len__'): - val = [val] + items = [val] else: - val = list(val) + items = list(val) item_type = converter.py_type.item_type if item_type == float: item_type = (float, int) - for i, v in enumerate(val): + for i, v in enumerate(items): + if PY2 and isinstance(v, str): + v = v.decode('ascii') if not isinstance(v, item_type): if hasattr(v, '__index__'): - val[i] = v.__index__() + items[i] = v.__index__() else: throw(TypeError, 'Cannot store %s item in array of %s' % (type(v).__name__, converter.py_type.item_type.__name__)) - return TrackedArray(obj, converter.attr, val) + if obj is None or converter.attr is None: + return items + return TrackedArray(obj, converter.attr, items) def dbval2val(converter, dbval, obj=None): if obj is None: diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 69f880a95..679d1adb1 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -203,6 +203,7 @@ def normalize_type(t): if t in (slice, type(Ellipsis)): return t if issubclass(t, basestring): return unicode if issubclass(t, (dict, Json)): return Json + if issubclass(t, Array): return t throw(TypeError, 'Unsupported type %r' % t.__name__) coercions = { @@ -341,6 +342,8 @@ def get_untracked(self): return [val.get_untracked() if isinstance(val, TrackedValue) else val for val in self] def validate_item(item_type, item): + if PY2 and isinstance(item, str): + item = item.decode('ascii') if not isinstance(item, item_type): if item_type is not unicode and hasattr(item, '__index__'): return 
item.__index__() @@ -374,27 +377,14 @@ def __repr__(self): return '' % self.wrapped class Array(object): - def __init__(self, item_type): - if item_type not in (unicode, int, float): - throw(NotImplementedError, 'Only int, float and str types are supported. Got: `Array(%r)`' % item_type) - self.item_type = item_type - - def __repr__(self): - return 'Array(%s)' % self.item_type.__name__ + item_type = None # Should be overridden in subclass - def __deepcopy__(self, memo): - return self +class IntArray(Array): + item_type = int - def __eq__(self, other): - return type(other) is Array and self.item_type == other.item_type - - def __ne__(self, other): - return not self.__eq__(other) - - def __hash__(self): - return hash(self.item_type) +class StrArray(Array): + item_type = unicode -IntArray = Array(int) -StrArray = Array(unicode) -FloatArray = Array(float) +class FloatArray(Array): + item_type = float diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 1116f418c..fd8d4999e 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2163,19 +2163,19 @@ def requires_distinct(monad, joined=False): class AttrMonad(Monad): @staticmethod def new(parent, attr, *args, **kwargs): - type = normalize_type(attr.py_type) - if type in numeric_types: cls = NumericAttrMonad - elif type is unicode: cls = StringAttrMonad - elif type is date: cls = DateAttrMonad - elif type is time: cls = TimeAttrMonad - elif type is timedelta: cls = TimedeltaAttrMonad - elif type is datetime: cls = DatetimeAttrMonad - elif type is buffer: cls = BufferAttrMonad - elif type is UUID: cls = UuidAttrMonad - elif type is Json: cls = JsonAttrMonad - elif isinstance(type, EntityMeta): cls = ObjectAttrMonad - elif isinstance(type, Array): cls = ArrayAttrMonad - else: throw(NotImplementedError, type) # pragma: no cover + t = normalize_type(attr.py_type) + if t in numeric_types: cls = NumericAttrMonad + elif t is unicode: cls = StringAttrMonad + elif t is date: cls = 
DateAttrMonad + elif t is time: cls = TimeAttrMonad + elif t is timedelta: cls = TimedeltaAttrMonad + elif t is datetime: cls = DatetimeAttrMonad + elif t is buffer: cls = BufferAttrMonad + elif t is UUID: cls = UuidAttrMonad + elif t is Json: cls = JsonAttrMonad + elif isinstance(t, EntityMeta): cls = ObjectAttrMonad + elif isinstance(t, type) and issubclass(t, Array): cls = ArrayAttrMonad + else: throw(NotImplementedError, t) # pragma: no cover return cls(parent, attr, *args, **kwargs) def __new__(cls, *args): if cls is AttrMonad: assert False, 'Abstract class' # pragma: no cover @@ -2232,33 +2232,33 @@ class ArrayAttrMonad(ArrayMixin, AttrMonad): pass class ParamMonad(Monad): @staticmethod - def new(type, paramkey): - type = normalize_type(type) - if type in numeric_types: cls = NumericParamMonad - elif type is unicode: cls = StringParamMonad - elif type is date: cls = DateParamMonad - elif type is time: cls = TimeParamMonad - elif type is timedelta: cls = TimedeltaParamMonad - elif type is datetime: cls = DatetimeParamMonad - elif type is buffer: cls = BufferParamMonad - elif type is UUID: cls = UuidParamMonad - elif type is Json: cls = JsonParamMonad - elif type is Array: cls = ArrayParamMonad - elif isinstance(type, EntityMeta): cls = ObjectParamMonad - else: throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (type,)) - result = cls(type, paramkey) + def new(t, paramkey): + t = normalize_type(t) + if t in numeric_types: cls = NumericParamMonad + elif t is unicode: cls = StringParamMonad + elif t is date: cls = DateParamMonad + elif t is time: cls = TimeParamMonad + elif t is timedelta: cls = TimedeltaParamMonad + elif t is datetime: cls = DatetimeParamMonad + elif t is buffer: cls = BufferParamMonad + elif t is UUID: cls = UuidParamMonad + elif t is Json: cls = JsonParamMonad + elif isinstance(t, type) and issubclass(t, Array): cls = ArrayParamMonad + elif isinstance(t, EntityMeta): cls = ObjectParamMonad + else: 
throw(NotImplementedError, 'Parameter {EXPR} has unsupported type %r' % (t,)) + result = cls(t, paramkey) result.aggregated = False return result def __new__(cls, *args): if cls is ParamMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) - def __init__(monad, type, paramkey): - type = normalize_type(type) - Monad.__init__(monad, type, nullable=False) + def __init__(monad, t, paramkey): + t = normalize_type(t) + Monad.__init__(monad, t, nullable=False) monad.paramkey = paramkey - if not isinstance(type, EntityMeta): + if not isinstance(t, EntityMeta): provider = monad.translator.database.provider - monad.converter = provider.get_converter_by_py_type(type) + monad.converter = provider.get_converter_by_py_type(t) else: monad.converter = None def getsql(monad, sqlquery=None): return [ [ 'PARAM', monad.paramkey, monad.converter ] ] @@ -2294,18 +2294,18 @@ def getsql(monad, sqlquery=None): class ExprMonad(Monad): @staticmethod - def new(type, sql, nullable=True): - if type in numeric_types: cls = NumericExprMonad - elif type is unicode: cls = StringExprMonad - elif type is date: cls = DateExprMonad - elif type is time: cls = TimeExprMonad - elif type is timedelta: cls = TimedeltaExprMonad - elif type is datetime: cls = DatetimeExprMonad - elif type is Json: cls = JsonExprMonad - elif isinstance(type, EntityMeta): cls = ObjectExprMonad - elif isinstance(type, Array): cls = ArrayExprMonad - else: throw(NotImplementedError, type) # pragma: no cover - return cls(type, sql, nullable=nullable) + def new(t, sql, nullable=True): + if t in numeric_types: cls = NumericExprMonad + elif t is unicode: cls = StringExprMonad + elif t is date: cls = DateExprMonad + elif t is time: cls = TimeExprMonad + elif t is timedelta: cls = TimedeltaExprMonad + elif t is datetime: cls = DatetimeExprMonad + elif t is Json: cls = JsonExprMonad + elif isinstance(t, EntityMeta): cls = ObjectExprMonad + elif isinstance(t, type) and issubclass(t, Array): cls = 
ArrayExprMonad + else: throw(NotImplementedError, t) # pragma: no cover + return cls(t, sql, nullable=nullable) def __new__(cls, *args, **kwargs): if cls is ExprMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) From b34baf674dfd454ea25202500632d1150778fa71 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 17 Nov 2018 12:49:48 +0300 Subject: [PATCH 412/547] Optional arrays should be NOT NULL by default --- pony/orm/core.py | 9 +++++---- pony/orm/dbproviders/sqlite.py | 3 ++- pony/orm/ormtypes.py | 8 ++++++++ pony/orm/tests/test_array.py | 14 ++++++++++++++ 4 files changed, 29 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e79cdd150..3c4c55ccf 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1013,7 +1013,7 @@ def get_columns(table, column_names): m2m_table.m2m.add(reverse) else: if attr.is_required: pass - elif not attr.is_string: + elif not attr.type_has_empty_value: if attr.nullable is False: throw(TypeError, 'Optional attribute with non-string type %s must be nullable' % attr) attr.nullable = True @@ -1976,7 +1976,7 @@ class Attribute(object): 'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \ 'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \ 'cascade_delete', 'index', 'original_default', 'sql_default', 'py_check', 'hidden', \ - 'optimistic', 'fk_name' + 'optimistic', 'fk_name', 'type_has_empty_value' def __deepcopy__(attr, memo): return attr # Attribute cannot be cloned by deepcopy() @cut_traceback @@ -2002,6 +2002,7 @@ def __init__(attr, py_type, *args, **kwargs): throw(TypeError, 'Incorrect type of attribute: %r' % py_type) attr.py_type = py_type attr.is_string = type(py_type) is type and issubclass(py_type, basestring) + attr.type_has_empty_value = attr.is_string or hasattr(attr.py_type, 'default_empty_value') attr.is_collection = isinstance(attr, Collection) attr.is_relation = isinstance(attr.py_type, 
(EntityMeta, basestring, types.FunctionType)) attr.is_basic = not attr.is_collection and not attr.is_relation @@ -2074,8 +2075,8 @@ def _init_(attr, entity, name): 'Default value for required attribute %s cannot be empty string' % attr) elif attr.default is None and not attr.nullable: throw(TypeError, 'Default value for non-nullable attribute %s cannot be set to None' % attr) - elif attr.is_string and not attr.is_required and not attr.nullable: - attr.default = '' + elif attr.type_has_empty_value and not attr.is_required and not attr.nullable: + attr.default = '' if attr.is_string else attr.py_type.default_empty_value() else: attr.default = None diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 6bb224a3b..483959ff8 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -231,7 +231,8 @@ class SQLiteArrayConverter(dbapiprovider.ArrayConverter): } def dbval2val(converter, dbval, obj=None): - items = json.loads(dbval) if dbval else [] + if not dbval: return None + items = json.loads(dbval) if obj is None: return items return TrackedArray(obj, converter.attr, items) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 679d1adb1..a29e45535 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -370,6 +370,10 @@ def __setitem__(self, index, item): class Json(object): """A wrapper over a dict or list """ + @classmethod + def default_empty_value(cls): + return {} + def __init__(self, wrapped): self.wrapped = wrapped @@ -379,6 +383,10 @@ def __repr__(self): class Array(object): item_type = None # Should be overridden in subclass + @classmethod + def default_empty_value(cls): + return [] + class IntArray(Array): item_type = int diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index 976fea628..2278d52a8 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -11,6 +11,8 @@ class Foo(db.Entity): array1 = Required(IntArray, index=True) array2 = 
Required(FloatArray) array3 = Required(StrArray) + array4 = Optional(IntArray) + array5 = Optional(IntArray, nullable=True) db.generate_mapping(create_tables=True) @@ -74,3 +76,15 @@ def test_9(self): def test_10(self): foos = select(f.array1[1:-1] for f in Foo)[:] self.assertEqual([2, 3, 4], foos[0]) + + @db_session + def test_11(self): + foo = Foo.select().first() + foo.array4.append(1) + self.assertEqual([1], foo.array4) + + @raises_exception(AttributeError, "'NoneType' object has no attribute 'append'") + @db_session + def test_12(self): + foo = Foo.select().first() + foo.array5.append(1) From 5b4ba26b18d942e3f34b1fc08351ee9e0e91f307 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 21 Nov 2018 13:10:05 +0300 Subject: [PATCH 413/547] Array fixes --- pony/orm/ormtypes.py | 21 ++++++++++++++++----- pony/orm/sqltranslation.py | 6 +++--- pony/orm/tests/test_array.py | 5 +++++ pony/orm/tests/test_query.py | 2 +- 4 files changed, 25 insertions(+), 9 deletions(-) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index a29e45535..65c0706ad 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -142,11 +142,6 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) -numeric_types = {bool, int, float, Decimal} -comparable_types = {int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID} -primitive_types = comparable_types | {buffer} -function_types = {type, types.FunctionType, types.BuiltinFunctionType} -type_normalization_dict = { long : int } if PY2 else {} def normalize(value): t = type(value) @@ -387,12 +382,28 @@ class Array(object): def default_empty_value(cls): return [] + class IntArray(Array): item_type = int + class StrArray(Array): item_type = unicode + class FloatArray(Array): item_type = float + +numeric_types = {bool, int, float, Decimal} +comparable_types = {int, float, Decimal, unicode, date, time, datetime, timedelta, bool, UUID, IntArray, StrArray, FloatArray} +primitive_types = 
comparable_types | {buffer} +function_types = {type, types.FunctionType, types.BuiltinFunctionType} +type_normalization_dict = { long : int } if PY2 else {} + +array_types = { + int: IntArray, + float: FloatArray, + unicode: StrArray +} + diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index fd8d4999e..82c051a97 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -18,7 +18,7 @@ from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ normalize, normalize_type, coerce_types, are_comparable_types, \ - Json, QueryType, Array + Json, QueryType, Array, array_types from pony.orm import core from pony.orm.core import EntityMeta, Set, JOIN, OptimizationFailed, Attribute, DescWrapper, \ special_functions, const_functions, extract_vars, Query, UseAnotherTranslator @@ -142,7 +142,7 @@ def dispatch_external(translator, node): elif tt is tuple: params = [] is_array = False - if translator.database.provider.array_converter_cls is None: + if translator.database.provider.array_converter_cls is not None: types = set(t) if len(types) == 1 and unicode in types: item_type = unicode @@ -158,7 +158,7 @@ def dispatch_external(translator, node): is_array = True if is_array: - array_type = Array(item_type) + array_type = array_types.get(item_type, None) monad = ArrayParamMonad(array_type, (varkey, None, None)) else: for i, item_type in enumerate(t): diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index 2278d52a8..0bbe27b2a 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -88,3 +88,8 @@ def test_11(self): def test_12(self): foo = Foo.select().first() foo.array5.append(1) + + @db_session + def test_13(self): + x = [1, 2, 3, 4, 5] + select(f for f in Foo if x == f.array1)[:] diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 911638ac6..42559afbd 100644 --- a/pony/orm/tests/test_query.py +++ 
b/pony/orm/tests/test_query.py @@ -51,7 +51,7 @@ def test3(self): "`a` raises NameError: name 'a' is not defined") def test4(self): select(a for s in Student) - @raises_exception(TypeError, "Incomparable types '%s' and 'list' in expression: s.name == x" % unicode.__name__) + @raises_exception(TypeError, "Incomparable types '%s' and 'StrArray' in expression: s.name == x" % unicode.__name__) def test5(self): x = ['A'] select(s for s in Student if s.name == x) From 5960ac45936f447c5f885022f4446fc503fdca7d Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 24 Nov 2018 15:42:55 +0300 Subject: [PATCH 414/547] Fix array __getitem__ --- pony/orm/sqltranslation.py | 66 ++++++++++++++------------------------ 1 file changed, 24 insertions(+), 42 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 82c051a97..4d20c2635 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2066,56 +2066,38 @@ def len(monad): def nonzero(monad): return BoolExprMonad(['GT', ['ARRAY_LENGTH', monad.getsql()[0]], ['VALUE', 0]]) - def __getitem__(monad, index): + def _index(monad, index, from_one, plus_one): if isinstance(index, NumericConstMonad): expr_sql = monad.getsql()[0] - index = index.getsql()[0] - value = index[1] - if not monad.translator.database.provider.dialect == 'SQLite': + index_sql = index.getsql()[0] + value = index_sql[1] + if from_one and plus_one: if value >= 0: - index = ['VALUE', value + 1] + index_sql = ['VALUE', value + 1] else: - index = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value) + 1]] + index_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value) + 1]] + + return index_sql + elif isinstance(index, NumericMixin): + expr_sql = monad.getsql()[0] + index0 = index.getsql()[0] + index1 = ['ADD', index0, ['VALUE', 1]] if from_one and plus_one else index0 + index_sql = ['CASE', None, [[['GE', index0, ['VALUE', 0]], index1]], + ['ADD', ['ARRAY_LENGTH', expr_sql], index1]] + return index_sql - 
sql = ['ARRAY_INDEX', expr_sql, index] + def __getitem__(monad, index): + dialect = monad.translator.database.provider.dialect + expr_sql = monad.getsql()[0] + from_one = dialect != 'SQLite' + if isinstance(index, NumericMixin): + index_sql = monad._index(index, from_one, plus_one=True) + sql = ['ARRAY_INDEX', expr_sql, index_sql] return ExprMonad.new(monad.type.item_type, sql) elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') - start, stop = index.start, index.stop - if start is None and stop is None: - return monad - - if start is not None and start.type is not int: - throw(TypeError, "Invalid type of start index (expected 'int', got %r) in array slice {EXPR}" - % type2str(start.type)) - if stop is not None and stop.type is not int: - throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in array slice {EXPR}" - % type2str(stop.type)) - - if (start is not None and not isinstance(start, NumericConstMonad)) or \ - (stop is not None and not isinstance(stop, NumericConstMonad)): - throw(TypeError, 'Array indices should be type of int') - - expr_sql = monad.getsql()[0] - - if not monad.translator.database.provider.dialect == 'SQLite': - if start is None: - start_sql = None - elif start.value >= 0: - start_sql = ['VALUE', start.value + 1] - else: - start_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(start.value) + 1]] - - if stop is None: - stop_sql = None - elif stop.value >= 0: - stop_sql = ['VALUE', stop.value + 1] - else: - stop_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(stop.value) + 1]] - else: - start_sql = None if start is None else ['VALUE', start.value] - stop_sql = None if stop is None else ['VALUE', stop.value] - + start_sql = monad._index(index.start, from_one, plus_one=True) + stop_sql = monad._index(index.stop, from_one, plus_one=False) sql = ['ARRAY_SLICE', expr_sql, start_sql, stop_sql] return ExprMonad.new(monad.type, sql) From 
9c01370cd97a673fa568d6c987c9ea490170adf9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 24 Nov 2018 18:11:23 +0300 Subject: [PATCH 415/547] More array tests added --- pony/orm/tests/test_array.py | 130 +++++++++++++++++++++++++++++++++-- 1 file changed, 123 insertions(+), 7 deletions(-) diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index 0bbe27b2a..435ad0187 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -8,6 +8,10 @@ db = Database('sqlite', ':memory:') class Foo(db.Entity): + id = PrimaryKey(int) + a = Required(int) + b = Required(int) + c = Required(int) array1 = Required(IntArray, index=True) array2 = Required(FloatArray) array3 = Required(StrArray) @@ -18,22 +22,22 @@ class Foo(db.Entity): with db_session: - Foo(array1=[1, 2, 3, 4, 5], array2=[1.1, 2.2, 3.3, 4.4, 5.5], array3=['foo', 'bar']) + Foo(id=1, a=1, b=3, c=-2, array1=[10, 20, 30, 40, 50], array2=[1.1, 2.2, 3.3, 4.4, 5.5], array3=['foo', 'bar']) class Test(unittest.TestCase): @db_session def test_1(self): - foo = select(f for f in Foo if 1 in f.array1)[:] + foo = select(f for f in Foo if 10 in f.array1)[:] self.assertEqual([Foo[1]], foo) @db_session def test_2(self): - foo = select(f for f in Foo if [1, 2, 5] in f.array1)[:] + foo = select(f for f in Foo if [10, 20, 50] in f.array1)[:] self.assertEqual([Foo[1]], foo) @db_session def test_3(self): - x = [1, 2, 5] + x = [10, 20, 50] foo = select(f for f in Foo if x in f.array1)[:] self.assertEqual([Foo[1]], foo) @@ -75,7 +79,7 @@ def test_9(self): @db_session def test_10(self): foos = select(f.array1[1:-1] for f in Foo)[:] - self.assertEqual([2, 3, 4], foos[0]) + self.assertEqual([20, 30, 40], foos[0]) @db_session def test_11(self): @@ -91,5 +95,117 @@ def test_12(self): @db_session def test_13(self): - x = [1, 2, 3, 4, 5] - select(f for f in Foo if x == f.array1)[:] + x = [10, 20, 30, 40, 50] + ids = select(f.id for f in Foo if x == f.array1)[:] + self.assertEqual(ids, [1]) + + 
@db_session + def test_14(self): + val = select(f.array1[0] for f in Foo).first() + self.assertEqual(val, 10) + + @db_session + def test_15(self): + val = select(f.array1[2] for f in Foo).first() + self.assertEqual(val, 30) + + @db_session + def test_16(self): + val = select(f.array1[-1] for f in Foo).first() + self.assertEqual(val, 50) + + @db_session + def test_17(self): + val = select(f.array1[-2] for f in Foo).first() + self.assertEqual(val, 40) + + @db_session + def test_18(self): + x = 2 + val = select(f.array1[x] for f in Foo).first() + self.assertEqual(val, 30) + + @db_session + def test_19(self): + val = select(f.array1[f.a] for f in Foo).first() + self.assertEqual(val, 20) + + @db_session + def test_20(self): + val = select(f.array1[f.c] for f in Foo).first() + self.assertEqual(val, 40) + + @db_session + def test_21(self): + array = select(f.array1[2:4] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_22(self): + array = select(f.array1[1:-2] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_23(self): + array = select(f.array1[10:-10] for f in Foo).first() + self.assertEqual(array, []) + + @db_session + def test_24(self): + x = 2 + array = select(f.array1[x:4] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_25(self): + y = 4 + array = select(f.array1[2:y] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_26(self): + x, y = 2, 4 + array = select(f.array1[x:y] for f in Foo).first() + self.assertEqual(array, [30, 40]) + + @db_session + def test_27(self): + x, y = 1, -2 + array = select(f.array1[x:y] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_28(self): + x = 1 + array = select(f.array1[x:f.b] for f in Foo).first() + self.assertEqual(array, [20, 30]) + + @db_session + def test_29(self): + array = select(f.array1[f.a:f.c] for f in Foo).first() + self.assertEqual(array, [20, 30]) 
+ + @db_session + def test_30(self): + array = select(f.array1[:3] for f in Foo).first() + self.assertEqual(array, [10, 20, 30]) + + @db_session + def test_31(self): + array = select(f.array1[2:] for f in Foo).first() + self.assertEqual(array, [30, 40, 50]) + + @db_session + def test_32(self): + array = select(f.array1[:f.b] for f in Foo).first() + self.assertEqual(array, [10, 20, 30]) + + @db_session + def test_33(self): + array = select(f.array1[:f.c] for f in Foo).first() + self.assertEqual(array, [10, 20, 30]) + + @db_session + def test_34(self): + array = select(f.array1[f.c:] for f in Foo).first() + self.assertEqual(array, [40, 50]) From c0b5a3c0d8e9b0c1390e4ed1148888923d9c2251 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 8 Dec 2018 15:27:06 +0300 Subject: [PATCH 416/547] Fix warnings filter in unittests --- pony/orm/tests/test_validate.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pony/orm/tests/test_validate.py b/pony/orm/tests/test_validate.py index 3ff4425d5..020822c1e 100644 --- a/pony/orm/tests/test_validate.py +++ b/pony/orm/tests/test_validate.py @@ -22,9 +22,6 @@ class Person(db.Entity): ) """) -warnings.simplefilter('error', ) - - class TestValidate(unittest.TestCase): @db_session From a42d351ac2f842b93c60aa4c17d470970e7b58be Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 9 Dec 2018 11:11:56 +0300 Subject: [PATCH 417/547] For nested db_session retry option should be ignored --- pony/orm/core.py | 36 ++++++++++++++++++++++--------- pony/orm/tests/test_db_session.py | 16 ++++++++++++-- 2 files changed, 40 insertions(+), 12 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 3c4c55ccf..a3c40b15d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -48,7 +48,7 @@ 'TransactionError', 'ConnectionClosedError', 'TransactionIntegrityError', 'IsolationError', 'CommitException', 'RollbackException', 'UnrepeatableReadError', 'OptimisticCheckError', 'UnresolvableCyclicDependency', 'UnexpectedError', 
'DatabaseSessionIsOver', - 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', + 'PonyRuntimeWarning', 'DatabaseContainsIncorrectValue', 'DatabaseContainsIncorrectEmptyValue', 'TranslationError', 'ExprEvalError', 'PermissionError', 'Database', 'sql_debug', 'set_sql_debug', 'sql_debugging', 'show', @@ -221,7 +221,10 @@ def __init__(self, translator): Exception.__init__(self, 'This exception should be catched internally by PonyORM') self.translator = translator -class DatabaseContainsIncorrectValue(RuntimeWarning): +class PonyRuntimeWarning(RuntimeWarning): + pass + +class DatabaseContainsIncorrectValue(PonyRuntimeWarning): pass class DatabaseContainsIncorrectEmptyValue(DatabaseContainsIncorrectValue): @@ -498,11 +501,24 @@ def _commit_or_rollback(db_session, exc_type, exc, tb): local.user_roles_cache.clear() def _wrap_function(db_session, func): def new_func(func, *args, **kwargs): - if db_session.ddl and local.db_context_counter: - if isinstance(func, types.FunctionType): func = func.__name__ + '()' - throw(TransactionError, '%s cannot be called inside of db_session' % func) - if db_session.sql_debug is not None: + if local.db_context_counter: + if db_session.ddl: + fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func + throw(TransactionError, '@db_session-decorated %s function with `ddl` option ' + 'cannot be called inside of another db_session' % fname) + if db_session.retry: + fname = func.__name__ + '()' if isinstance(func, types.FunctionType) else func + message = '@db_session decorator with `retry=%d` option is ignored for %s function ' \ + 'because it is called inside another db_session' % (db_session.retry, fname) + warnings.warn(message, PonyRuntimeWarning, stacklevel=3) + if db_session.sql_debug is None: + return func(*args, **kwargs) local.push_debug_state(db_session.sql_debug, db_session.show_values) + try: + return func(*args, **kwargs) + finally: + local.pop_debug_state() + exc = tb = None try: for i in 
xrange(db_session.retry+1): @@ -520,13 +536,13 @@ def new_func(func, *args, **kwargs): else: assert exc is not None # exc can be None in Python 2.6 do_retry = retry_exceptions(exc) - if not do_retry: raise - finally: db_session.__exit__(exc_type, exc, tb) + if not do_retry: + raise + finally: + db_session.__exit__(exc_type, exc, tb) reraise(exc_type, exc, tb) finally: del exc, tb - if db_session.sql_debug is not None: - local.pop_debug_state() return decorator(new_func, func) def _wrap_coroutine_or_generator_function(db_session, gen_func): for option in ('ddl', 'retry', 'serializable'): diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index d9a496c95..8f0d8795d 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -1,6 +1,6 @@ from __future__ import absolute_import, print_function, division -import unittest +import unittest, warnings from datetime import date from decimal import Decimal from itertools import count @@ -254,6 +254,17 @@ def test(): else: self.fail() + @raises_exception(PonyRuntimeWarning, '@db_session decorator with `retry=3` option is ignored for test() function ' + 'because it is called inside another db_session') + def test_retry_11(self): + @db_session(retry=3) + def test(): + pass + with warnings.catch_warnings(): + warnings.simplefilter('error', PonyRuntimeWarning) + with db_session: + test() + def test_db_session_manager_1(self): with db_session: self.X(a=3, b=3) @@ -313,7 +324,8 @@ def test_db_session_ddl_1c(self): with db_session(ddl=True): pass - @raises_exception(TransactionError, "test() cannot be called inside of db_session") + @raises_exception(TransactionError, "@db_session-decorated test() function with `ddl` option " + "cannot be called inside of another db_session") def test_db_session_ddl_2(self): @db_session(ddl=True) def test(): From 82cfc7432a391874e2f613b6492ba931e0a2b323 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 24 Nov 2018 20:18:51 +0300 
Subject: [PATCH 418/547] contains for tracked array --- pony/orm/ormtypes.py | 6 ++++++ pony/orm/tests/test_array.py | 14 ++++++++++++++ 2 files changed, 20 insertions(+) diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 65c0706ad..66adf4589 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -362,6 +362,12 @@ def __setitem__(self, index, item): item = validate_item(self.item_type, item) TrackedList.__setitem__(self, index, item) + def __contains__(self, item): + if not isinstance(item, basestring) and hasattr(item, '__iter__'): + return all(it in set(self) for it in item) + return list.__contains__(self, item) + + class Json(object): """A wrapper over a dict or list """ diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index 435ad0187..728b9ceb5 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -209,3 +209,17 @@ def test_33(self): def test_34(self): array = select(f.array1[f.c:] for f in Foo).first() self.assertEqual(array, [40, 50]) + + @db_session + def test_35(self): + foo = Foo.select().first() + self.assertTrue(10 in foo.array1) + self.assertTrue(1000 not in foo.array1) + self.assertTrue([10, 20] in foo.array1) + self.assertTrue([20, 10] in foo.array1) + self.assertTrue([10, 1000] not in foo.array1) + self.assertTrue('bar' in foo.array3) + self.assertTrue('baz' not in foo.array3) + self.assertTrue(['foo', 'bar'] in foo.array3) + self.assertTrue(['bar', 'foo'] in foo.array3) + self.assertTrue(['baz', 'bar'] not in foo.array3) From 926a557a811c799a7d15a6f50701127ed7595a6d Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 9 Dec 2018 12:46:15 +0300 Subject: [PATCH 419/547] Fixes #405: Breaking change with cx_Oracle 7.0: DML RETURNING statements now return a list --- pony/orm/dbproviders/oracle.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index c52cd09ff..57d86d990 100644 --- 
a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -471,7 +471,11 @@ def execute(provider, cursor, sql, arguments=None, returning_id=False): arguments['new_id'] = var if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) - return var.getvalue() + value = var.getvalue() + if isinstance(value, list): + assert len(value) == 1 + value = value[0] + return value if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) From 05216295e1e1ab0a8a6cc9fdc9b1fbabebe097c5 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 4 Jan 2019 13:14:29 +0300 Subject: [PATCH 420/547] Optimization of ast2src(node) --- pony/orm/asttranslation.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pony/orm/asttranslation.py b/pony/orm/asttranslation.py index 379afef6c..6c9068b60 100644 --- a/pony/orm/asttranslation.py +++ b/pony/orm/asttranslation.py @@ -28,8 +28,8 @@ def dispatch(translator, node): except KeyError: pre_method = getattr(translator_cls, 'pre' + node_cls.__name__, translator_cls.default_pre) pre_methods[node_cls] = pre_method - stop = translator.call(pre_method, node) + stop = translator.call(pre_method, node) if stop: return for child in node.getChildNodes(): @@ -61,6 +61,9 @@ def binop_src(op, node): return op.join((node.left.src, node.right.src)) def ast2src(tree): + src = getattr(tree, 'src', None) + if src is not None: + return src PythonTranslator(tree) return tree.src @@ -71,6 +74,9 @@ def __init__(translator, tree): translator.dispatch(tree) def call(translator, method, node): node.src = method(translator, node) + def default_pre(translator, node): + if getattr(node, 'src', None) is not None: + return True # node.src is already calculated, stop dispatching def default_post(translator, node): throw(NotImplementedError, node) def postGenExpr(translator, node): From 6e340ffcd5f37b34d9a907e60e8e6435339fe4fd Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 15 Dec 
2018 15:05:33 +0300 Subject: [PATCH 421/547] Add pony.egg-info/ to .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index d58e61b2f..b42e08466 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ pony/orm/tests/coverage.bat pony/orm/tests/htmlcov/*.* MANIFEST docs/_build/ +pony.egg-info/ From d2508c4947ceb0f2b2dc673048ea79d9b67dc811 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 15 Dec 2018 20:27:51 +0300 Subject: [PATCH 422/547] MySQL `group_concat_max_len` option set to 32-bit platforms' max value --- pony/orm/dbproviders/mysql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index ae8eecb37..23ea895aa 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -227,6 +227,7 @@ def inspect_connection(provider, connection): provider.max_time_precision = 6 cursor.execute('select database()') provider.default_schema_name = cursor.fetchone()[0] + cursor.execute('set session group_concat_max_len = 4294967295') def should_reconnect(provider, exc): return isinstance(exc, mysql_module.OperationalError) and exc.args[0] == 2006 From 68bb40f05a98856984aef4dd13835d8efe755348 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 15 Dec 2018 19:37:10 +0300 Subject: [PATCH 423/547] Add total_stat across all SQL queries: db.local_stats[None] --- pony/orm/core.py | 29 ++++++++++++++++++----------- pony/orm/tests/test_prefetching.py | 10 +++++----- 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index a3c40b15d..7e306276f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -791,16 +791,27 @@ def _update_local_stat(database, sql, query_start_time): dblocal = database._dblocal dblocal.last_sql = sql stats = dblocal.stats + query_end_time = time() + duration = query_end_time - query_start_time + stat = stats.get(sql) - if stat is not None: 
stat.query_executed(query_start_time) - else: stats[sql] = QueryStat(sql, query_start_time) + if stat is not None: + stat.query_executed(duration) + else: + stats[sql] = QueryStat(sql, duration) + + total_stat = stats.get(None) + if total_stat is not None: + total_stat.query_executed(duration) + else: + stats[None] = QueryStat(None, duration) def merge_local_stats(database): setdefault = database._global_stats.setdefault with database._global_stats_lock: for sql, stat in iteritems(database._dblocal.stats): global_stat = setdefault(sql, stat) if global_stat is not stat: global_stat.merge(stat) - database._dblocal.stats.clear() + database._dblocal.stats = {None: QueryStat(None)} @property def global_stats(database): with database._global_stats_lock: @@ -1659,14 +1670,12 @@ def decorator(func): class DbLocal(localbase): def __init__(dblocal): - dblocal.stats = {} + dblocal.stats = {None: QueryStat(None)} dblocal.last_sql = None class QueryStat(object): - def __init__(stat, sql, query_start_time=None): - if query_start_time is not None: - query_end_time = time() - duration = query_end_time - query_start_time + def __init__(stat, sql, duration=None): + if duration is not None: stat.min_time = stat.max_time = stat.sum_time = duration stat.db_count = 1 stat.cache_count = 0 @@ -1679,9 +1688,7 @@ def copy(stat): result = object.__new__(QueryStat) result.__dict__.update(stat.__dict__) return result - def query_executed(stat, query_start_time): - query_end_time = time() - duration = query_end_time - query_start_time + def query_executed(stat, duration): if stat.db_count: stat.min_time = builtins.min(stat.min_time, duration) stat.max_time = builtins.max(stat.max_time, duration) diff --git a/pony/orm/tests/test_prefetching.py b/pony/orm/tests/test_prefetching.py index f5f4d40e3..d7afb54b4 100644 --- a/pony/orm/tests/test_prefetching.py +++ b/pony/orm/tests/test_prefetching.py @@ -120,7 +120,7 @@ def test_13(self): for g in q: # 1 query for s in g.students: # 2 query b = 
s.biography # 5 queries - query_count = sum(stat.db_count for stat in db.local_stats.values()) + query_count = db.local_stats[None].db_count self.assertEqual(query_count, 8) def test_14(self): @@ -130,7 +130,7 @@ def test_14(self): for g in q: # 1 query for s in g.students: # 1 query b = s.biography # 5 queries - query_count = sum(stat.db_count for stat in db.local_stats.values()) + query_count = db.local_stats[None].db_count self.assertEqual(query_count, 7) def test_15(self): @@ -143,7 +143,7 @@ def test_15(self): for g in q: # 1 query for s in g.students: # 1 query b = s.biography # 0 queries - query_count = sum(stat.db_count for stat in db.local_stats.values()) + query_count = db.local_stats[None].db_count self.assertEqual(query_count, 2) def test_16(self): @@ -153,7 +153,7 @@ def test_16(self): for c in q: # 1 query for s in c.students: # 2 queries (as it is many-to-many relationship) b = s.biography # 0 queries - query_count = sum(stat.db_count for stat in db.local_stats.values()) + query_count = db.local_stats[None].db_count self.assertEqual(query_count, 3) def test_17(self): @@ -164,7 +164,7 @@ def test_17(self): for s in c.students: # 2 queries (as it is many-to-many relationship) m = s.group.major # 1 query b = s.biography # 0 queries - query_count = sum(stat.db_count for stat in db.local_stats.values()) + query_count = db.local_stats[None].db_count self.assertEqual(query_count, 4) From b442cb302beb2dbbbd58888741cb97e871274377 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 5 Jan 2019 17:32:35 +0300 Subject: [PATCH 424/547] Wrap obj.get_pk() with @cut_traceback --- pony/orm/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 7e306276f..778fb6b45 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4587,6 +4587,7 @@ def __init__(obj, *args, **kwargs): obj._save_pos_ = len(objects_to_save) objects_to_save.append(obj) cache.modified = True + @cut_traceback def get_pk(obj): pkval = 
obj._get_raw_pkval_() if len(pkval) == 1: return pkval[0] From aa16b4f789cb93995f676601d26d2387c32870f7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 5 Jan 2019 17:49:54 +0300 Subject: [PATCH 425/547] A bit more aggressive cache clearing after rollback in interactive mode --- pony/orm/core.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 778fb6b45..f9dcb6a1b 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1835,17 +1835,16 @@ def close(cache, rollback=True): provider.release(connection, cache) finally: db_session = cache.db_session or local.db_session - if db_session: - if db_session.strict: - for obj in cache.objects: - obj._vals_ = obj._dbvals_ = obj._session_cache_ = None - cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None - else: - for obj in cache.objects: - obj._dbvals_ = obj._session_cache_ = None - for attr, setdata in iteritems(obj._vals_): - if attr.is_collection: - if not setdata.is_fully_loaded: obj._vals_[attr] = None + if db_session and db_session.strict: + for obj in cache.objects: + obj._vals_ = obj._dbvals_ = obj._session_cache_ = None + cache.perm_cache = cache.user_roles_cache = cache.obj_labels_cache = None + else: + for obj in cache.objects: + obj._dbvals_ = obj._session_cache_ = None + for attr, setdata in iteritems(obj._vals_): + if attr.is_collection: + if not setdata.is_fully_loaded: obj._vals_[attr] = None cache.objects = cache.objects_to_save = cache.saved_objects = cache.query_results \ = cache.indexes = cache.seeds = cache.for_update = cache.max_id_cache \ From e5f06c858be9c9f299c8380cca72926b907ba9d8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 6 Jan 2019 17:23:24 +0300 Subject: [PATCH 426/547] Fix accessing global variables from hybrid methods and properties --- pony/orm/sqltranslation.py | 2 +- .../test_hybrid_methods_and_properties.py | 44 ++++++++++++++++--- 2 files changed, 39 
insertions(+), 7 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 4d20c2635..a0ebd2656 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -882,7 +882,7 @@ def apply_lambda(translator, func_id, filter_num, order_by, func_ast, argnames, translator.code_key = func_id translator.filter_num = filter_num translator.extractors.update(extractors) - translator.vars = vars.copy() if vars is not None else None + translator.vars = vars translator.vartypes = translator.vartypes.copy() # make HashableDict mutable again translator.vartypes.update(vartypes) diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py index 64e2f1cda..aa54a52b8 100644 --- a/pony/orm/tests/test_hybrid_methods_and_properties.py +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -5,7 +5,10 @@ db = Database('sqlite', ':memory:') +sep = ' ' + class Person(db.Entity): + id = PrimaryKey(int) first_name = Required(str) last_name = Required(str) favorite_color = Optional(str) @@ -13,11 +16,11 @@ class Person(db.Entity): @property def full_name(self): - return self.first_name + ' ' + self.last_name + return self.first_name + sep + self.last_name @property def full_name_2(self): - return concat(self.first_name, ' ', self.last_name) # tests using of function `concat` from external scope + return concat(self.first_name, sep, self.last_name) # tests using of function `concat` from external scope @property def has_car(self): @@ -60,10 +63,10 @@ class Car(db.Entity): db.generate_mapping(create_tables=True) with db_session: - p1 = Person(first_name='Alexander', last_name='Kozlovsky', favorite_color='white') - p2 = Person(first_name='Alexei', last_name='Malashkevich', favorite_color='green') - p3 = Person(first_name='Vitaliy', last_name='Abetkin') - p4 = Person(first_name='Alexander', last_name='Tischenko', favorite_color='blue') + p1 = Person(id=1, first_name='Alexander', 
last_name='Kozlovsky', favorite_color='white') + p2 = Person(id=2, first_name='Alexei', last_name='Malashkevich', favorite_color='green') + p3 = Person(id=3, first_name='Vitaliy', last_name='Abetkin') + p4 = Person(id=4, first_name='Alexander', last_name='Tischenko', favorite_color='blue') c1 = Car(brand='Peugeot', model='306', owner=p1, year=2006, price=14000, color='red') c2 = Car(brand='Honda', model='Accord', owner=p1, year=2007, price=13850, color='white') @@ -157,6 +160,35 @@ def test15(self): result = Person.find_by_full_name('Alexander Tischenko') self.assertEqual(set(obj.last_name for obj in result), {'Tischenko'}) + @db_session + def test16(self): + result = Person.select(lambda p: p.full_name == 'Alexander Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + + @db_session + def test17(self): + global sep + sep = '.' + try: + result = Person.select(lambda p: p.full_name == 'Alexander.Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + finally: + sep = ' ' + + @db_session + def test18(self): + result = Person.select().filter(lambda p: p.full_name == 'Alexander Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + + @db_session + def test19(self): + global sep + sep = '.' 
+ try: + result = Person.select().filter(lambda p: p.full_name == 'Alexander.Kozlovsky') + self.assertEqual(set(p.id for p in result), {1}) + finally: + sep = ' ' if __name__ == '__main__': unittest.main() From f3a590fe34325c1bb77aaa1c89421c242a64453e Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 5 Jan 2019 20:09:47 +0300 Subject: [PATCH 427/547] Support of select(x for x in y.items) --- pony/orm/core.py | 41 +++++++++++++++++++++++++++++++++--- pony/orm/tests/test_query.py | 4 ++-- 2 files changed, 40 insertions(+), 5 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index f9dcb6a1b..783297095 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -3218,6 +3218,35 @@ def unpickle_setwrapper(obj, attrname, items): setdata.count = len(setdata) return wrapper + +class SetIterator(object): + def __init__(self, wrapper): + self._wrapper = wrapper + self._query = None + self._iter = None + + def __iter__(self): + return self + + def next(self): + if self._iter is None: + self._iter = iter(self._wrapper.copy()) + return next(self._iter) + + __next__ = next + + def _get_query(self): + if self._query is None: + self._query = self._wrapper.select() + return self._query + + def _get_type_(self): + return QueryType(self._get_query()) + + def _normalize_var(self, query_type): + return query_type, self._get_query() + + class SetInstance(object): __slots__ = '_obj_', '_attr_', '_attrnames_' _parent_ = None @@ -3340,7 +3369,7 @@ def count(wrapper): return setdata.count @cut_traceback def __iter__(wrapper): - return iter(wrapper.copy()) + return SetIterator(wrapper) @cut_traceback def __eq__(wrapper, other): if isinstance(other, SetInstance): @@ -5486,8 +5515,8 @@ def extract_vars(code_key, filter_num, extractors, globals, locals, cells=None): if isinstance(value, QueryResult) and value._items: value = tuple(value._items) - if isinstance(value, (Query, QueryResult)): - query = value._query if isinstance(value, QueryResult) else value + if 
isinstance(value, (Query, QueryResult, SetIterator)): + query = value._get_query() vars.update(query._vars) vartypes.update(query._translator.vartypes) @@ -5529,6 +5558,8 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False prev_query = origin._query elif isinstance(origin, QueryResultIterator): prev_query = origin._query_result._query + elif isinstance(origin, SetIterator): + prev_query = origin._query else: prev_query = None if not isinstance(origin, EntityMeta): @@ -5583,6 +5614,8 @@ def __init__(query, code_key, tree, globals, locals, cells=None, left_join=False query._distinct = None query._prefetch = False query._prefetch_context = PrefetchContext(query._database) + def _get_query(query): + return query def _get_type_(query): return QueryType(query) def _normalize_var(query, query_type): @@ -6140,6 +6173,8 @@ def __init__(self, query, limit, offset, lazy): self._items = None if lazy else self._query._actual_fetch(limit, offset) self._expr_type = translator.expr_type self._col_names = translator.col_names + def _get_query(self): + return self._query def _get_type_(self): if self._items is None: return QueryType(self._query, self._limit, self._offset) diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 42559afbd..328d16fca 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -43,10 +43,10 @@ def test1(self): def test2(self): X = [1, 2, 3] select('x for x in X') - @raises_exception(TypeError, "Query can only iterate over entity or another query (not a list of objects)") def test3(self): g = Group[1] - select(s for s in g.students) + students = select(s for s in g.students) + self.assertEqual(set(g.students), set(students)) @raises_exception(ExprEvalError, "`a` raises NameError: global name 'a' is not defined" if PYPY2 else "`a` raises NameError: name 'a' is not defined") def test4(self): From 6aa22d10089471a26daa8c18e0c90a682b7f3036 Mon Sep 17 00:00:00 2001 From: Alexander 
Tischenko Date: Fri, 4 Jan 2019 13:22:26 +0300 Subject: [PATCH 428/547] Fix flask extension: validate for attribute wasn't expect LocalProxy objects --- pony/orm/core.py | 3 ++- pony/orm/ormtypes.py | 7 ++----- pony/utils/utils.py | 6 ++++++ 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 783297095..6e87b44c6 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -32,7 +32,7 @@ from pony import utils from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ - between, concat, coalesce, HashableDict + between, concat, coalesce, HashableDict, deref_flask_local_proxy __all__ = [ 'pony', @@ -2144,6 +2144,7 @@ def __repr__(attr): def __lt__(attr, other): return attr.id < other.id def validate(attr, val, obj=None, entity=None, from_db=False): + val = deref_flask_local_proxy(val) if val is None: if not attr.nullable and not from_db and not attr.is_required: # for required attribute the exception will be thrown later with another message diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 66adf4589..8e58afe3b 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -7,7 +7,7 @@ from functools import wraps, WRAPPER_ASSIGNMENTS from uuid import UUID -from pony.utils import throw, parse_expr +from pony.utils import throw, parse_expr, deref_flask_local_proxy NoneType = type(None) @@ -144,11 +144,8 @@ def __ne__(self, other): def normalize(value): + value = deref_flask_local_proxy(value) t = type(value) - if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__: - value = value._get_current_object() - t = type(value) - if t is tuple: item_types, item_values = [], [] for item in value: diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 9bfce9d8d..c607234fc 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py 
@@ -590,3 +590,9 @@ def __deepcopy__(self, memo): popitem = _hashable_wrap(dict.popitem) setdefault = _hashable_wrap(dict.setdefault) update = _hashable_wrap(dict.update) + +def deref_flask_local_proxy(value): + t = type(value) + if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__: + value = value._get_current_object() + return value From 1bd29b514bbfdf118cb99a72a1509b99bd1801bb Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 15 Dec 2018 19:55:48 +0300 Subject: [PATCH 429/547] make_proxy(obj) function creates a proxy object which can be used across different db_sessions --- pony/orm/core.py | 64 +- pony/orm/ormtypes.py | 4 +- pony/orm/tests/test_entity_proxy.py | 161 ++++ pony/utils/utils.py | 1201 ++++++++++++++------------- 4 files changed, 827 insertions(+), 603 deletions(-) create mode 100644 pony/orm/tests/test_entity_proxy.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 6e87b44c6..82ffeec61 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -32,7 +32,7 @@ from pony import utils from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ - between, concat, coalesce, HashableDict, deref_flask_local_proxy + between, concat, coalesce, HashableDict, deref_proxy __all__ = [ 'pony', @@ -55,7 +55,7 @@ 'PrimaryKey', 'Required', 'Optional', 'Set', 'Discriminator', 'composite_key', 'composite_index', - 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', + 'flush', 'commit', 'rollback', 'db_session', 'with_transaction', 'make_proxy', 'LongStr', 'LongUnicode', 'Json', 'IntArray', 'StrArray', 'FloatArray', @@ -2144,7 +2144,7 @@ def __repr__(attr): def __lt__(attr, other): return attr.id < other.id def validate(attr, val, obj=None, entity=None, from_db=False): - val = deref_flask_local_proxy(val) + val = deref_proxy(val) if val is None: if not 
attr.nullable and not from_db and not attr.is_required: # for required attribute the exception will be thrown later with another message @@ -4547,6 +4547,64 @@ def unpickle_entity(d): def safe_repr(obj): return Entity.__repr__(obj) +def make_proxy(obj): + proxy = EntityProxy(obj) + return proxy + +class EntityProxy(object): + def __init__(self, obj): + entity = obj.__class__ + object.__setattr__(self, '_entity_', entity) + pkval = obj.get_pk() + if pkval is None: + cache = obj._session_cache_ + if obj._status_ in del_statuses or cache is None or not cache.is_alive: + throw(ValueError, 'Cannot make a proxy for %s object: primary key is not specified' % entity.__name__) + flush() + pkval = obj.get_pk() + assert pkval is not None + object.__setattr__(self, '_obj_pk_', pkval) + + def __repr__(self): + entity = self._entity_ + pkval = self._obj_pk_ + pkrepr = ','.join(repr(item) for item in pkval) if isinstance(pkval, tuple) else repr(pkval) + return '' % (entity.__name__, pkrepr) + + def _get_object(self): + entity = self._entity_ + pkval = self._obj_pk_ + cache = entity._database_._get_cache() + attrs = entity._pk_attrs_ + if attrs in cache.indexes and pkval in cache.indexes[attrs]: + obj = cache.indexes[attrs][pkval] + else: + obj = entity[pkval] + return obj + + def __getattr__(self, name): + obj = self._get_object() + return getattr(obj, name) + + def __setattr__(self, name, value): + obj = self._get_object() + setattr(obj, name, value) + + def __eq__(self, other): + entity = self._entity_ + pkval = self._obj_pk_ + if isinstance(other, EntityProxy): + entity2 = other._entity_ + pkval2 = other._obj_pk_ + return entity == entity2 and pkval == pkval2 + elif isinstance(other, entity): + return pkval == other._pkval_ + return False + + def __ne__(self, other): + return not self.__eq__(other) + + class Entity(with_metaclass(EntityMeta)): __slots__ = '_session_cache_', '_status_', '_pkval_', '_newid_', '_dbvals_', '_vals_', '_rbits_', '_wbits_', '_save_pos_', 
'__weakref__' def __reduce__(obj): diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index 8e58afe3b..ae6cccd8f 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -7,7 +7,7 @@ from functools import wraps, WRAPPER_ASSIGNMENTS from uuid import UUID -from pony.utils import throw, parse_expr, deref_flask_local_proxy +from pony.utils import throw, parse_expr, deref_proxy NoneType = type(None) @@ -144,7 +144,7 @@ def __ne__(self, other): def normalize(value): - value = deref_flask_local_proxy(value) + value = deref_proxy(value) t = type(value) if t is tuple: item_types, item_values = [], [] diff --git a/pony/orm/tests/test_entity_proxy.py b/pony/orm/tests/test_entity_proxy.py new file mode 100644 index 000000000..ab8130ab4 --- /dev/null +++ b/pony/orm/tests/test_entity_proxy.py @@ -0,0 +1,161 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * + +class TestProxy(unittest.TestCase): + def setUp(self): + db = self.db = Database('sqlite', ':memory:') + + class Country(db.Entity): + id = PrimaryKey(int) + name = Required(str) + persons = Set("Person") + + class Person(db.Entity): + id = PrimaryKey(int) + name = Required(str) + country = Required(Country) + + db.generate_mapping(create_tables=True) + + with db_session: + c1 = Country(id=1, name='Russia') + c2 = Country(id=2, name='Japan') + Person(id=1, name='Alexander Nevskiy', country=c1) + Person(id=2, name='Raikou Minamoto', country=c2) + Person(id=3, name='Ibaraki Douji', country=c2) + + + def test_1(self): + db = self.db + with db_session: + p = make_proxy(db.Person[2]) + + with db_session: + x1 = db.local_stats[None].db_count # number of queries + # it is possible to access p attributes in a new db_session + name = p.name + country = p.country + x2 = db.local_stats[None].db_count + + # p.name and p.country are loaded with a single query + self.assertEqual(x1, x2-1) + + def test_2(self): + db = self.db + with db_session: + p = make_proxy(db.Person[2]) + name = p.name 
+ country = p.country + + with db_session: + x1 = db.local_stats[None].db_count + name = p.name + country = p.country + x2 = db.local_stats[None].db_count + + # attribute values from the first db_session should be ignored and loaded again + self.assertEqual(x1, x2-1) + + def test_3(self): + db = self.db + with db_session: + p = db.Person[2] + proxy = make_proxy(p) + + with db_session: + p2 = db.Person[2] + name1 = 'Tamamo no Mae' + # It is possible to assign new attribute values to a proxy object + p2.name = name1 + name2 = proxy.name + + self.assertEqual(name1, name2) + + + def test_4(self): + db = self.db + with db_session: + p = db.Person[2] + proxy = make_proxy(p) + + with db_session: + p2 = db.Person[2] + name1 = 'Tamamo no Mae' + p2.name = name1 + + with db_session: + # new attribute value was successfully stored in the database + name2 = proxy.name + + self.assertEqual(name1, name2) + + def test_5(self): + db = self.db + with db_session: + p = db.Person[2] + r = repr(p) + self.assertEqual(r, 'Person[2]') + + proxy = make_proxy(p) + r = repr(proxy) + # proxy object has specific repr + self.assertEqual(r, '') + + r = repr(proxy) + # repr of proxy object can be used outside of db_session + self.assertEqual(r, '') + + del p + r = repr(proxy) + # repr works even if the original object was deleted + self.assertEqual(r, '') + + + def test_6(self): + db = self.db + with db_session: + p = db.Person[2] + proxy = make_proxy(p) + proxy.name = 'Okita Souji' + # after assignment, the attribute value is the same for the proxy and for the original object + self.assertEqual(proxy.name, 'Okita Souji') + self.assertEqual(p.name, 'Okita Souji') + + + def test_7(self): + db = self.db + with db_session: + p = db.Person[2] + proxy = make_proxy(p) + proxy.name = 'Okita Souji' + # after assignment, the attribute value is the same for the proxy and for the original object + self.assertEqual(proxy.name, 'Okita Souji') + self.assertEqual(p.name, 'Okita Souji') + + + def test_8(self): + 
db = self.db + with db_session: + c1 = db.Country[1] + c1_proxy = make_proxy(c1) + p2 = db.Person[2] + self.assertNotEqual(p2.country, c1) + self.assertNotEqual(p2.country, c1_proxy) + # proxy can be used in attribute assignment + p2.country = c1_proxy + self.assertEqual(p2.country, c1_proxy) + self.assertIs(p2.country, c1) + + + def test_9(self): + db = self.db + with db_session: + c2 = db.Country[2] + c2_proxy = make_proxy(c2) + persons = select(p for p in db.Person if p.country == c2_proxy) + self.assertEqual({p.id for p in persons}, {2, 3}) + +if __name__ == '__main__': + unittest.main() diff --git a/pony/utils/utils.py b/pony/utils/utils.py index c607234fc..abf429eb6 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -1,598 +1,603 @@ -#coding: cp1251 - -from __future__ import absolute_import, print_function -from pony.py23compat import PY2, imap, basestring, unicode, pickle, iteritems - -import io, re, os, os.path, sys, datetime, inspect, types, linecache, warnings, json - -from itertools import count as _count -from inspect import isfunction, ismethod -from time import strptime -from os import urandom -from codecs import BOM_UTF8, BOM_LE, BOM_BE -from locale import getpreferredencoding -from bisect import bisect -from collections import defaultdict -from functools import update_wrapper, wraps -from xml.etree import cElementTree -from copy import deepcopy - -import pony -from pony import options - -from pony.thirdparty.compiler import ast -from pony.thirdparty.decorator import decorator as _decorator - -if pony.MODE.startswith('GAE-'): localbase = object -else: from threading import local as localbase - - -class PonyDeprecationWarning(DeprecationWarning): - pass - -def deprecated(stacklevel, message): - warnings.warn(message, PonyDeprecationWarning, stacklevel) - -warnings.simplefilter('once', PonyDeprecationWarning) - -def _improved_decorator(caller, func): - if isfunction(func): - return _decorator(caller, func) - def pony_wrapper(*args, 
**kwargs): - return caller(func, *args, **kwargs) - return pony_wrapper - -def decorator(caller, func=None): - if func is not None: - return _improved_decorator(caller, func) - def new_decorator(func): - return _improved_decorator(caller, func) - if isfunction(caller): - update_wrapper(new_decorator, caller) - return new_decorator - -##def simple_decorator(dec): -## def new_dec(func): -## def pony_wrapper(*args, **kwargs): -## return dec(func, *args, **kwargs) -## return copy_func_attrs(pony_wrapper, func, dec.__name__) -## return copy_func_attrs(new_dec, dec, 'simple_decorator') - -##@simple_decorator -##def decorator_with_params(dec, *args, **kwargs): -## if len(args) == 1 and not kwargs: -## func = args[0] -## new_func = dec(func) -## return copy_func_attrs(new_func, func, dec.__name__) -## def parameterized_decorator(old_func): -## new_func = dec(func, *args, **kwargs) -## return copy_func_attrs(new_func, func, dec.__name__) -## return parameterized_decorator - -def decorator_with_params(dec): - def parameterized_decorator(*args, **kwargs): - if len(args) == 1 and isfunction(args[0]) and not kwargs: - return decorator(dec(), args[0]) - return decorator(dec(*args, **kwargs)) - return parameterized_decorator - -@decorator -def cut_traceback(func, *args, **kwargs): - if not options.CUT_TRACEBACK: - return func(*args, **kwargs) - - try: return func(*args, **kwargs) - except AssertionError: raise - except Exception: - exc_type, exc, tb = sys.exc_info() - full_tb = tb - last_pony_tb = None - try: - while tb.tb_next: - module_name = tb.tb_frame.f_globals['__name__'] - if module_name == 'pony' or (module_name is not None # may be None during import - and module_name.startswith('pony.')): - last_pony_tb = tb - tb = tb.tb_next - if last_pony_tb is None: raise - module_name = tb.tb_frame.f_globals.get('__name__') or '' - if module_name.startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': - reraise(exc_type, exc, last_pony_tb) - reraise(exc_type, exc, 
full_tb) - finally: - del exc, full_tb, tb, last_pony_tb - -cut_traceback_depth = 2 - -if pony.MODE != 'INTERACTIVE': - cut_traceback_depth = 0 - def cut_traceback(func): - return func - -if PY2: - exec('''def reraise(exc_type, exc, tb): - try: raise exc_type, exc, tb - finally: del tb''') -else: - def reraise(exc_type, exc, tb): - try: raise exc.with_traceback(tb) - finally: del exc, tb - -def throw(exc_type, *args, **kwargs): - if isinstance(exc_type, Exception): - assert not args and not kwargs - exc = exc_type - else: exc = exc_type(*args, **kwargs) - exc.__cause__ = None - try: - if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): - raise exc - else: - raise exc # Set "pony.options.CUT_TRACEBACK = False" to see full traceback - finally: del exc - -def truncate_repr(s, max_len=100): - s = repr(s) - return s if len(s) <= max_len else s[:max_len-3] + '...' - -codeobjects = {} - -def get_codeobject_id(codeobject): - codeobject_id = id(codeobject) - if codeobject_id not in codeobjects: - codeobjects[codeobject_id] = codeobject - return codeobject_id - -lambda_args_cache = {} - -def get_lambda_args(func): - if type(func) is types.FunctionType: - codeobject = func.func_code if PY2 else func.__code__ - cache_key = get_codeobject_id(codeobject) - elif isinstance(func, ast.Lambda): - cache_key = func - else: assert False # pragma: no cover - - names = lambda_args_cache.get(cache_key) - if names is not None: return names - - if type(func) is types.FunctionType: - if hasattr(inspect, 'signature'): - names, argsname, kwname, defaults = [], None, None, [] - for p in inspect.signature(func).parameters.values(): - if p.default is not p.empty: - defaults.append(p.default) - - if p.kind == p.POSITIONAL_OR_KEYWORD: - names.append(p.name) - elif p.kind == p.VAR_POSITIONAL: - argsname = p.name - elif p.kind == p.VAR_KEYWORD: - kwname = p.name - elif p.kind == p.POSITIONAL_ONLY: - throw(TypeError, 'Positional-only arguments like %s are not supported' % p.name) - elif 
p.kind == p.KEYWORD_ONLY: - throw(TypeError, 'Keyword-only arguments like %s are not supported' % p.name) - else: assert False - else: - names, argsname, kwname, defaults = inspect.getargspec(func) - elif isinstance(func, ast.Lambda): - names = func.argnames - if func.kwargs: names, kwname = names[:-1], names[-1] - else: kwname = None - if func.varargs: names, argsname = names[:-1], names[-1] - else: argsname = None - defaults = func.defaults - else: assert False # pragma: no cover - if argsname: throw(TypeError, '*%s is not supported' % argsname) - if kwname: throw(TypeError, '**%s is not supported' % kwname) - if defaults: throw(TypeError, 'Defaults are not supported') - - lambda_args_cache[cache_key] = names - return names - -_cache = {} -MAX_CACHE_SIZE = 1000 - -@decorator -def cached(f, *args, **kwargs): - key = (f, args, tuple(sorted(kwargs.items()))) - value = _cache.get(key) - if value is not None: return value - if len(_cache) == MAX_CACHE_SIZE: _cache.clear() - return _cache.setdefault(key, f(*args, **kwargs)) - -def error_method(*args, **kwargs): - raise TypeError() - -_ident_re = re.compile(r'^[A-Za-z_]\w*\Z') - -# is_ident = ident_re.match -def is_ident(string): - 'is_ident(string) -> bool' - return bool(_ident_re.match(string)) - -_name_parts_re = re.compile(r''' - [A-Z][A-Z0-9]+(?![a-z]) # ACRONYM - | [A-Z][a-z]* # Capitalized or single capital - | [a-z]+ # all-lowercase - | [0-9]+ # numbers - | _+ # underscores - ''', re.VERBOSE) - -def split_name(name): - "split_name('Some_FUNNYName') -> ['Some', 'FUNNY', 'Name']" - if not _ident_re.match(name): - raise ValueError('Name is not correct Python identifier') - list = _name_parts_re.findall(name) - if not (list[0].strip('_') and list[-1].strip('_')): - raise ValueError('Name must not starting or ending with underscores') - return [ s for s in list if s.strip('_') ] - -def uppercase_name(name): - "uppercase_name('Some_FUNNYName') -> 'SOME_FUNNY_NAME'" - return '_'.join(s.upper() for s in 
split_name(name)) - -def lowercase_name(name): - "uppercase_name('Some_FUNNYName') -> 'some_funny_name'" - return '_'.join(s.lower() for s in split_name(name)) - -def camelcase_name(name): - "uppercase_name('Some_FUNNYName') -> 'SomeFunnyName'" - return ''.join(s.capitalize() for s in split_name(name)) - -def mixedcase_name(name): - "mixedcase_name('Some_FUNNYName') -> 'someFunnyName'" - list = split_name(name) - return list[0].lower() + ''.join(s.capitalize() for s in list[1:]) - -def import_module(name): - "import_module('a.b.c') -> " - mod = sys.modules.get(name) - if mod is not None: return mod - mod = __import__(name) - components = name.split('.') - for comp in components[1:]: mod = getattr(mod, comp) - return mod - -if sys.platform == 'win32': - _absolute_re = re.compile(r'^(?:[A-Za-z]:)?[\\/]') -else: _absolute_re = re.compile(r'^/') - -def is_absolute_path(filename): - return bool(_absolute_re.match(filename)) - -def absolutize_path(filename, frame_depth): - if is_absolute_path(filename): return filename - code_filename = sys._getframe(frame_depth+1).f_code.co_filename - if not is_absolute_path(code_filename): - if code_filename.startswith('<') and code_filename.endswith('>'): - if pony.MODE == 'INTERACTIVE': raise ValueError( - 'When in interactive mode, please provide absolute file path. 
Got: %r' % filename) - raise EnvironmentError('Unexpected module filename, which is not absolute file path: %r' % code_filename) - code_path = os.path.dirname(code_filename) - return os.path.join(code_path, filename) - -def shortened_filename(filename): - if pony.MAIN_DIR is None: return filename - maindir = pony.MAIN_DIR + os.sep - if filename.startswith(maindir): return filename[len(maindir):] - return filename - -def get_mtime(filename): - stat = os.stat(filename) - mtime = stat.st_mtime - if sys.platform == "win32": mtime -= stat.st_ctime - return mtime - -coding_re = re.compile(r'coding[:=]\s*([-\w.]+)') - -def detect_source_encoding(filename): - for i, line in enumerate(linecache.getlines(filename)): - if i == 0 and line.startswith(BOM_UTF8): return 'utf-8' - if not line.lstrip().startswith('#'): continue - match = coding_re.search(line) - if match is not None: return match.group(1) - else: return options.SOURCE_ENCODING or getpreferredencoding() - -escape_re = re.compile(r''' - (?' 
% (x.__class__.__name__) - return '<%s object at 0x%X>' % (x.__class__.__name__) - -def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None): - "Can join mix of unicode and byte strings in different encodings" - strings = list(strings) - try: return sep.join(strings) - except UnicodeDecodeError: pass - for i, s in enumerate(strings): - if isinstance(s, str): - strings[i] = s.decode(source_encoding, 'replace').replace(u'\ufffd', '?') - result = sep.join(strings) - if dest_encoding is None: return result - return result.encode(dest_encoding, 'replace') - -def make_offsets(s): - offsets = [ 0 ] - si = -1 - try: - while True: - si = s.index('\n', si + 1) - offsets.append(si + 1) - except ValueError: pass - offsets.append(len(s)) - return offsets - -def pos2lineno(pos, offsets): - line = bisect(offsets, pos, 0, len(offsets)-1) - if line == 1: offset = pos - else: offset = pos - offsets[line - 1] - return line, offset - -def getline(text, offsets, lineno): - return text[offsets[lineno-1]:offsets[lineno]] - -def getlines(text, offsets, lineno, context=1): - if context <= 0: return [], None - start = max(0, lineno - 1 - context//2) - end = min(len(offsets)-1, start + context) - start = max(0, end - context) - lines = [] - for i in range(start, end): lines.append(text[offsets[i]:offsets[i+1]]) - index = lineno - 1 - start - return lines, index - -def getlines2(filename, lineno, context=1): - if context <= 0: return [], None - lines = linecache.getlines(filename) - if not lines: return [], None - start = max(0, lineno - 1 - context//2) - end = min(len(lines), start + context) - start = max(0, end - context) - lines = lines[start:start+context] - index = lineno - 1 - start - return lines, index - -def count(*args, **kwargs): - if kwargs: return _count(*args, **kwargs) - if len(args) != 1: return _count(*args) - arg = args[0] - if hasattr(arg, 'count'): return arg.count() - try: it = iter(arg) - except TypeError: return _count(arg) - return len(set(it)) - -def 
avg(iter): - count = 0 - sum = 0.0 - for elem in iter: - if elem is None: continue - sum += elem - count += 1 - if not count: return None - return sum / count - -def group_concat(items, sep=','): - if items is None: - return None - return str(sep).join(str(item) for item in items) - -def coalesce(*args): - for arg in args: - if arg is not None: - return arg - return None - -def distinct(iter): - d = defaultdict(int) - for item in iter: - d[item] = d[item] + 1 - return d - -def concat(*args): - return ''.join(tostring(arg) for arg in args) - -def between(x, a, b): - return a <= x <= b - -def is_utf8(encoding): - return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') - -def _persistent_id(obj): - if obj is Ellipsis: - return "Ellipsis" - -def _persistent_load(persid): - if persid == "Ellipsis": - return Ellipsis - raise pickle.UnpicklingError("unsupported persistent object") - -def pickle_ast(val): - pickled = io.BytesIO() - pickler = pickle.Pickler(pickled) - pickler.persistent_id = _persistent_id - pickler.dump(val) - return pickled - -def unpickle_ast(pickled): - pickled.seek(0) - unpickler = pickle.Unpickler(pickled) - unpickler.persistent_load = _persistent_load - return unpickler.load() - -def copy_ast(tree): - return unpickle_ast(pickle_ast(tree)) - -def _hashable_wrap(func): - @wraps(func, assigned=('__name__', '__doc__')) - def new_func(self, *args, **kwargs): - if getattr(self, '_hash', None) is not None: - assert False, 'Cannot mutate HashableDict instance after the hash value is calculated' - return func(self, *args, **kwargs) - return new_func - -class HashableDict(dict): - def __hash__(self): - result = getattr(self, '_hash', None) - if result is None: - result = 0 - for key, value in self.items(): - result ^= hash(key) - result ^= hash(value) - self._hash = result - return result - def __deepcopy__(self, memo): - if getattr(self, '_hash', None) is not None: - return self - return HashableDict({deepcopy(key, memo): 
deepcopy(value, memo) - for key, value in iteritems(self)}) - __setitem__ = _hashable_wrap(dict.__setitem__) - __delitem__ = _hashable_wrap(dict.__delitem__) - clear = _hashable_wrap(dict.clear) - pop = _hashable_wrap(dict.pop) - popitem = _hashable_wrap(dict.popitem) - setdefault = _hashable_wrap(dict.setdefault) - update = _hashable_wrap(dict.update) - -def deref_flask_local_proxy(value): - t = type(value) - if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__: - value = value._get_current_object() - return value +#coding: cp1251 + +from __future__ import absolute_import, print_function +from pony.py23compat import PY2, imap, basestring, unicode, pickle, iteritems + +import io, re, os, os.path, sys, datetime, inspect, types, linecache, warnings, json + +from itertools import count as _count +from inspect import isfunction, ismethod +from time import strptime +from os import urandom +from codecs import BOM_UTF8, BOM_LE, BOM_BE +from locale import getpreferredencoding +from bisect import bisect +from collections import defaultdict +from functools import update_wrapper, wraps +from xml.etree import cElementTree +from copy import deepcopy + +import pony +from pony import options + +from pony.thirdparty.compiler import ast +from pony.thirdparty.decorator import decorator as _decorator + +if pony.MODE.startswith('GAE-'): localbase = object +else: from threading import local as localbase + + +class PonyDeprecationWarning(DeprecationWarning): + pass + +def deprecated(stacklevel, message): + warnings.warn(message, PonyDeprecationWarning, stacklevel) + +warnings.simplefilter('once', PonyDeprecationWarning) + +def _improved_decorator(caller, func): + if isfunction(func): + return _decorator(caller, func) + def pony_wrapper(*args, **kwargs): + return caller(func, *args, **kwargs) + return pony_wrapper + +def decorator(caller, func=None): + if func is not None: + return _improved_decorator(caller, func) + def new_decorator(func): + return 
_improved_decorator(caller, func) + if isfunction(caller): + update_wrapper(new_decorator, caller) + return new_decorator + +##def simple_decorator(dec): +## def new_dec(func): +## def pony_wrapper(*args, **kwargs): +## return dec(func, *args, **kwargs) +## return copy_func_attrs(pony_wrapper, func, dec.__name__) +## return copy_func_attrs(new_dec, dec, 'simple_decorator') + +##@simple_decorator +##def decorator_with_params(dec, *args, **kwargs): +## if len(args) == 1 and not kwargs: +## func = args[0] +## new_func = dec(func) +## return copy_func_attrs(new_func, func, dec.__name__) +## def parameterized_decorator(old_func): +## new_func = dec(func, *args, **kwargs) +## return copy_func_attrs(new_func, func, dec.__name__) +## return parameterized_decorator + +def decorator_with_params(dec): + def parameterized_decorator(*args, **kwargs): + if len(args) == 1 and isfunction(args[0]) and not kwargs: + return decorator(dec(), args[0]) + return decorator(dec(*args, **kwargs)) + return parameterized_decorator + +@decorator +def cut_traceback(func, *args, **kwargs): + if not options.CUT_TRACEBACK: + return func(*args, **kwargs) + + try: return func(*args, **kwargs) + except AssertionError: raise + except Exception: + exc_type, exc, tb = sys.exc_info() + full_tb = tb + last_pony_tb = None + try: + while tb.tb_next: + module_name = tb.tb_frame.f_globals['__name__'] + if module_name == 'pony' or (module_name is not None # may be None during import + and module_name.startswith('pony.')): + last_pony_tb = tb + tb = tb.tb_next + if last_pony_tb is None: raise + module_name = tb.tb_frame.f_globals.get('__name__') or '' + if module_name.startswith('pony.utils') and tb.tb_frame.f_code.co_name == 'throw': + reraise(exc_type, exc, last_pony_tb) + reraise(exc_type, exc, full_tb) + finally: + del exc, full_tb, tb, last_pony_tb + +cut_traceback_depth = 2 + +if pony.MODE != 'INTERACTIVE': + cut_traceback_depth = 0 + def cut_traceback(func): + return func + +if PY2: + exec('''def 
reraise(exc_type, exc, tb): + try: raise exc_type, exc, tb + finally: del tb''') +else: + def reraise(exc_type, exc, tb): + try: raise exc.with_traceback(tb) + finally: del exc, tb + +def throw(exc_type, *args, **kwargs): + if isinstance(exc_type, Exception): + assert not args and not kwargs + exc = exc_type + else: exc = exc_type(*args, **kwargs) + exc.__cause__ = None + try: + if not (pony.MODE == 'INTERACTIVE' and options.CUT_TRACEBACK): + raise exc + else: + raise exc # Set "pony.options.CUT_TRACEBACK = False" to see full traceback + finally: del exc + +def truncate_repr(s, max_len=100): + s = repr(s) + return s if len(s) <= max_len else s[:max_len-3] + '...' + +codeobjects = {} + +def get_codeobject_id(codeobject): + codeobject_id = id(codeobject) + if codeobject_id not in codeobjects: + codeobjects[codeobject_id] = codeobject + return codeobject_id + +lambda_args_cache = {} + +def get_lambda_args(func): + if type(func) is types.FunctionType: + codeobject = func.func_code if PY2 else func.__code__ + cache_key = get_codeobject_id(codeobject) + elif isinstance(func, ast.Lambda): + cache_key = func + else: assert False # pragma: no cover + + names = lambda_args_cache.get(cache_key) + if names is not None: return names + + if type(func) is types.FunctionType: + if hasattr(inspect, 'signature'): + names, argsname, kwname, defaults = [], None, None, [] + for p in inspect.signature(func).parameters.values(): + if p.default is not p.empty: + defaults.append(p.default) + + if p.kind == p.POSITIONAL_OR_KEYWORD: + names.append(p.name) + elif p.kind == p.VAR_POSITIONAL: + argsname = p.name + elif p.kind == p.VAR_KEYWORD: + kwname = p.name + elif p.kind == p.POSITIONAL_ONLY: + throw(TypeError, 'Positional-only arguments like %s are not supported' % p.name) + elif p.kind == p.KEYWORD_ONLY: + throw(TypeError, 'Keyword-only arguments like %s are not supported' % p.name) + else: assert False + else: + names, argsname, kwname, defaults = inspect.getargspec(func) + elif 
isinstance(func, ast.Lambda): + names = func.argnames + if func.kwargs: names, kwname = names[:-1], names[-1] + else: kwname = None + if func.varargs: names, argsname = names[:-1], names[-1] + else: argsname = None + defaults = func.defaults + else: assert False # pragma: no cover + if argsname: throw(TypeError, '*%s is not supported' % argsname) + if kwname: throw(TypeError, '**%s is not supported' % kwname) + if defaults: throw(TypeError, 'Defaults are not supported') + + lambda_args_cache[cache_key] = names + return names + +_cache = {} +MAX_CACHE_SIZE = 1000 + +@decorator +def cached(f, *args, **kwargs): + key = (f, args, tuple(sorted(kwargs.items()))) + value = _cache.get(key) + if value is not None: return value + if len(_cache) == MAX_CACHE_SIZE: _cache.clear() + return _cache.setdefault(key, f(*args, **kwargs)) + +def error_method(*args, **kwargs): + raise TypeError() + +_ident_re = re.compile(r'^[A-Za-z_]\w*\Z') + +# is_ident = ident_re.match +def is_ident(string): + 'is_ident(string) -> bool' + return bool(_ident_re.match(string)) + +_name_parts_re = re.compile(r''' + [A-Z][A-Z0-9]+(?![a-z]) # ACRONYM + | [A-Z][a-z]* # Capitalized or single capital + | [a-z]+ # all-lowercase + | [0-9]+ # numbers + | _+ # underscores + ''', re.VERBOSE) + +def split_name(name): + "split_name('Some_FUNNYName') -> ['Some', 'FUNNY', 'Name']" + if not _ident_re.match(name): + raise ValueError('Name is not correct Python identifier') + list = _name_parts_re.findall(name) + if not (list[0].strip('_') and list[-1].strip('_')): + raise ValueError('Name must not starting or ending with underscores') + return [ s for s in list if s.strip('_') ] + +def uppercase_name(name): + "uppercase_name('Some_FUNNYName') -> 'SOME_FUNNY_NAME'" + return '_'.join(s.upper() for s in split_name(name)) + +def lowercase_name(name): + "uppercase_name('Some_FUNNYName') -> 'some_funny_name'" + return '_'.join(s.lower() for s in split_name(name)) + +def camelcase_name(name): + 
"uppercase_name('Some_FUNNYName') -> 'SomeFunnyName'" + return ''.join(s.capitalize() for s in split_name(name)) + +def mixedcase_name(name): + "mixedcase_name('Some_FUNNYName') -> 'someFunnyName'" + list = split_name(name) + return list[0].lower() + ''.join(s.capitalize() for s in list[1:]) + +def import_module(name): + "import_module('a.b.c') -> " + mod = sys.modules.get(name) + if mod is not None: return mod + mod = __import__(name) + components = name.split('.') + for comp in components[1:]: mod = getattr(mod, comp) + return mod + +if sys.platform == 'win32': + _absolute_re = re.compile(r'^(?:[A-Za-z]:)?[\\/]') +else: _absolute_re = re.compile(r'^/') + +def is_absolute_path(filename): + return bool(_absolute_re.match(filename)) + +def absolutize_path(filename, frame_depth): + if is_absolute_path(filename): return filename + code_filename = sys._getframe(frame_depth+1).f_code.co_filename + if not is_absolute_path(code_filename): + if code_filename.startswith('<') and code_filename.endswith('>'): + if pony.MODE == 'INTERACTIVE': raise ValueError( + 'When in interactive mode, please provide absolute file path. 
Got: %r' % filename) + raise EnvironmentError('Unexpected module filename, which is not absolute file path: %r' % code_filename) + code_path = os.path.dirname(code_filename) + return os.path.join(code_path, filename) + +def shortened_filename(filename): + if pony.MAIN_DIR is None: return filename + maindir = pony.MAIN_DIR + os.sep + if filename.startswith(maindir): return filename[len(maindir):] + return filename + +def get_mtime(filename): + stat = os.stat(filename) + mtime = stat.st_mtime + if sys.platform == "win32": mtime -= stat.st_ctime + return mtime + +coding_re = re.compile(r'coding[:=]\s*([-\w.]+)') + +def detect_source_encoding(filename): + for i, line in enumerate(linecache.getlines(filename)): + if i == 0 and line.startswith(BOM_UTF8): return 'utf-8' + if not line.lstrip().startswith('#'): continue + match = coding_re.search(line) + if match is not None: return match.group(1) + else: return options.SOURCE_ENCODING or getpreferredencoding() + +escape_re = re.compile(r''' + (?' 
% (x.__class__.__name__) + return '<%s object at 0x%X>' % (x.__class__.__name__) + +def strjoin(sep, strings, source_encoding='ascii', dest_encoding=None): + "Can join mix of unicode and byte strings in different encodings" + strings = list(strings) + try: return sep.join(strings) + except UnicodeDecodeError: pass + for i, s in enumerate(strings): + if isinstance(s, str): + strings[i] = s.decode(source_encoding, 'replace').replace(u'\ufffd', '?') + result = sep.join(strings) + if dest_encoding is None: return result + return result.encode(dest_encoding, 'replace') + +def make_offsets(s): + offsets = [ 0 ] + si = -1 + try: + while True: + si = s.index('\n', si + 1) + offsets.append(si + 1) + except ValueError: pass + offsets.append(len(s)) + return offsets + +def pos2lineno(pos, offsets): + line = bisect(offsets, pos, 0, len(offsets)-1) + if line == 1: offset = pos + else: offset = pos - offsets[line - 1] + return line, offset + +def getline(text, offsets, lineno): + return text[offsets[lineno-1]:offsets[lineno]] + +def getlines(text, offsets, lineno, context=1): + if context <= 0: return [], None + start = max(0, lineno - 1 - context//2) + end = min(len(offsets)-1, start + context) + start = max(0, end - context) + lines = [] + for i in range(start, end): lines.append(text[offsets[i]:offsets[i+1]]) + index = lineno - 1 - start + return lines, index + +def getlines2(filename, lineno, context=1): + if context <= 0: return [], None + lines = linecache.getlines(filename) + if not lines: return [], None + start = max(0, lineno - 1 - context//2) + end = min(len(lines), start + context) + start = max(0, end - context) + lines = lines[start:start+context] + index = lineno - 1 - start + return lines, index + +def count(*args, **kwargs): + if kwargs: return _count(*args, **kwargs) + if len(args) != 1: return _count(*args) + arg = args[0] + if hasattr(arg, 'count'): return arg.count() + try: it = iter(arg) + except TypeError: return _count(arg) + return len(set(it)) + +def 
avg(iter): + count = 0 + sum = 0.0 + for elem in iter: + if elem is None: continue + sum += elem + count += 1 + if not count: return None + return sum / count + +def group_concat(items, sep=','): + if items is None: + return None + return str(sep).join(str(item) for item in items) + +def coalesce(*args): + for arg in args: + if arg is not None: + return arg + return None + +def distinct(iter): + d = defaultdict(int) + for item in iter: + d[item] = d[item] + 1 + return d + +def concat(*args): + return ''.join(tostring(arg) for arg in args) + +def between(x, a, b): + return a <= x <= b + +def is_utf8(encoding): + return encoding.upper().replace('_', '').replace('-', '') in ('UTF8', 'UTF', 'U8') + +def _persistent_id(obj): + if obj is Ellipsis: + return "Ellipsis" + +def _persistent_load(persid): + if persid == "Ellipsis": + return Ellipsis + raise pickle.UnpicklingError("unsupported persistent object") + +def pickle_ast(val): + pickled = io.BytesIO() + pickler = pickle.Pickler(pickled) + pickler.persistent_id = _persistent_id + pickler.dump(val) + return pickled + +def unpickle_ast(pickled): + pickled.seek(0) + unpickler = pickle.Unpickler(pickled) + unpickler.persistent_load = _persistent_load + return unpickler.load() + +def copy_ast(tree): + return unpickle_ast(pickle_ast(tree)) + +def _hashable_wrap(func): + @wraps(func, assigned=('__name__', '__doc__')) + def new_func(self, *args, **kwargs): + if getattr(self, '_hash', None) is not None: + assert False, 'Cannot mutate HashableDict instance after the hash value is calculated' + return func(self, *args, **kwargs) + return new_func + +class HashableDict(dict): + def __hash__(self): + result = getattr(self, '_hash', None) + if result is None: + result = 0 + for key, value in self.items(): + result ^= hash(key) + result ^= hash(value) + self._hash = result + return result + def __deepcopy__(self, memo): + if getattr(self, '_hash', None) is not None: + return self + return HashableDict({deepcopy(key, memo): 
deepcopy(value, memo) + for key, value in iteritems(self)}) + __setitem__ = _hashable_wrap(dict.__setitem__) + __delitem__ = _hashable_wrap(dict.__delitem__) + clear = _hashable_wrap(dict.clear) + pop = _hashable_wrap(dict.pop) + popitem = _hashable_wrap(dict.popitem) + setdefault = _hashable_wrap(dict.setdefault) + update = _hashable_wrap(dict.update) + +def deref_proxy(value): + t = type(value) + if t.__name__ == 'LocalProxy' and '_get_current_object' in t.__dict__: + # Flask local proxy + value = value._get_current_object() + elif t.__name__ == 'EntityProxy': + # Pony proxy + value = value._get_object() + + return value From 9e7cd21c5829f6306864e7af9bb7329c44376eef Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 12 Jan 2019 17:40:42 +0300 Subject: [PATCH 430/547] Add support of ON DELETE CASCADE/SET NULL --- pony/orm/core.py | 8 ++- pony/orm/dbschema.py | 11 ++-- pony/orm/tests/test_cascade.py | 92 ++++++++++++++++++++++++++++++++++ 3 files changed, 107 insertions(+), 4 deletions(-) create mode 100644 pony/orm/tests/test_cascade.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 82ffeec61..31b62bd31 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1111,7 +1111,13 @@ def get_columns(table, column_names): parent_table = schema.tables[rentity._table_] parent_columns = get_columns(parent_table, rentity._pk_columns_) child_columns = get_columns(table, attr.columns) - table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index) + if attr.reverse.cascade_delete: + on_delete = 'CASCADE' + elif isinstance(attr, Optional) and attr.nullable: + on_delete = 'SET NULL' + else: + on_delete = None + table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index, on_delete) elif attr.index and attr.columns: if isinstance(attr.py_type, Array) and provider.dialect != 'PostgreSQL': pass # GIN indexes are supported only in PostgreSQL diff --git a/pony/orm/dbschema.py 
b/pony/orm/dbschema.py index 5079b49f6..a2bc4e372 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -186,12 +186,12 @@ def add_index(table, index_name, columns, is_pk=False, is_unique=None, m2m=False if index and index.name == index_name and index.is_pk == is_pk and index.is_unique == is_unique: return index return table.schema.index_class(index_name, table, columns, is_pk, is_unique) - def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None): + def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None, on_delete=False): if fk_name is None: provider = table.schema.provider child_column_names = tuple(column.name for column in child_columns) fk_name = provider.get_default_fk_name(table.name, parent_table.name, child_column_names) - return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name) + return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name, on_delete) class Column(object): auto_template = '%(type)s PRIMARY KEY AUTOINCREMENT' @@ -239,6 +239,8 @@ def get_sql(column): append(case('REFERENCES')) append(quote_name(parent_table.name)) append(schema.column_list(foreign_key.parent_columns)) + if foreign_key.on_delete: + append('ON DELETE %s' % foreign_key.on_delete) return ' '.join(result) class Constraint(DBObject): @@ -322,7 +324,7 @@ def _get_create_sql(index, inside_table): class ForeignKey(Constraint): typename = 'Foreign key' - def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name): + def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name, on_delete): schema = parent_table.schema if schema is not child_table.schema: throw(DBSchemaError, 'Parent and child tables of foreign_key cannot belong to different schemata') @@ -348,6 +350,7 @@ def __init__(foreign_key, name, child_table, child_columns, 
parent_table, parent foreign_key.parent_columns = parent_columns foreign_key.child_table = child_table foreign_key.child_columns = child_columns + foreign_key.on_delete = on_delete if index_name is not False: child_columns_len = len(child_columns) @@ -379,6 +382,8 @@ def _get_create_sql(foreign_key, inside_table): append(case('REFERENCES')) append(quote_name(foreign_key.parent_table.name)) append(schema.column_list(foreign_key.parent_columns)) + if foreign_key.on_delete: + append(case('ON DELETE %s' % foreign_key.on_delete)) return ' '.join(cmd) DBSchema.table_class = Table diff --git a/pony/orm/tests/test_cascade.py b/pony/orm/tests/test_cascade.py new file mode 100644 index 000000000..a05cc07b3 --- /dev/null +++ b/pony/orm/tests/test_cascade.py @@ -0,0 +1,92 @@ +import unittest + +from pony.orm import * +from pony.orm.tests.testutils import * + +class TestCascade(unittest.TestCase): + + def test_1(self): + db = self.db = Database('sqlite', ':memory:') + + class Person(self.db.Entity): + name = Required(str) + group = Required('Group') + + class Group(self.db.Entity): + persons = Set(Person) + + db.generate_mapping(create_tables=True) + + self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command()) + + def test_2(self): + db = self.db = Database('sqlite', ':memory:') + + class Person(self.db.Entity): + name = Required(str) + group = Required('Group') + + class Group(self.db.Entity): + persons = Set(Person, cascade_delete=True) + + db.generate_mapping(create_tables=True) + + self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command()) + + + def test_3(self): + db = self.db = Database('sqlite', ':memory:') + + class Person(self.db.Entity): + name = Required(str) + group = Optional('Group') + + class Group(self.db.Entity): + persons = Set(Person, cascade_delete=True) + + db.generate_mapping(create_tables=True) + + self.assertTrue('ON DELETE CASCADE' in 
self.db.schema.tables['Person'].get_create_command()) + + @raises_exception(TypeError, "'cascade_delete' option cannot be set for attribute Group.persons, because reverse attribute Person.group is collection") + def test_4(self): + db = self.db = Database('sqlite', ':memory:') + + class Person(self.db.Entity): + name = Required(str) + group = Set('Group') + + class Group(self.db.Entity): + persons = Set(Person, cascade_delete=True) + + db.generate_mapping(create_tables=True) + + @raises_exception(TypeError, "'cascade_delete' option cannot be set for both sides of relationship (Person.group and Group.persons) simultaneously") + def test_5(self): + db = self.db = Database('sqlite', ':memory:') + + class Person(self.db.Entity): + name = Required(str) + group = Set('Group', cascade_delete=True) + + class Group(self.db.Entity): + persons = Required(Person, cascade_delete=True) + + db.generate_mapping(create_tables=True) + + def test_6(self): + db = self.db = Database('sqlite', ':memory:') + + class Person(self.db.Entity): + name = Required(str) + group = Set('Group') + + class Group(self.db.Entity): + persons = Optional(Person) + + db.generate_mapping(create_tables=True) + + self.assertTrue('ON DELETE SET NULL' in self.db.schema.tables['Group'].get_create_command()) + +if __name__ == '__main__': + unittest.main() From 7a7719b37d83dd92e1f6680d663dfed89cd2a3d7 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 12 Jan 2019 18:59:06 +0300 Subject: [PATCH 431/547] Show all attribute options in show(Entity) call --- pony/orm/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index 31b62bd31..e79639663 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2493,6 +2493,8 @@ def describe(attr): options = [] if attr.args: options.append(', '.join(imap(str, attr.args))) if attr.auto: options.append('auto=True') + for k, v in sorted(attr.kwargs.items()): + options.append('%s=%r' % (k, v)) if not isinstance(attr, PrimaryKey) 
and attr.is_unique: options.append('unique=True') if attr.default is not None: options.append('default=%r' % attr.default) if not options: options = '' From 342a15d5027e1757f4d79572ecedd2772d575c13 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 17 Jan 2019 11:28:44 +0300 Subject: [PATCH 432/547] Update changelog and change Pony version: 0.7.7-dev -> 0.7.7 --- CHANGELOG.md | 46 ++++++++++++++++++++++++++++++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ffffbd27..acb476677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,49 @@ +# PonyORM release 0.7.7 (2019-01-17) + +## Major features + +* Array type support for PostgreSQL and SQLite +* isinstance() support in queries +* Support of queries based on collections: select(x for x in y.items) + +## Other features + +* Support of Entity.select(**kwargs) +* Support of SKIP LOCKED option in 'SELECT ... FOR UPDATE' +* New function make_proxy(obj) to make cros-db_session proxy objects +* Specify ON DELETE CASCADE/SET NULL in foreign keys +* Support of LIMIT in `SELECT FROM (SELECT ...)` type of queries +* Support for negative JSON array indexes in SQLite + +## Improvements + +* Improved query prefetching: use fewer number of SQL queries +* Memory optimization: deduplication of values recieved from the database in the same session +* increase DBAPIProvider.max_params_count value + +## Bugfixes + +* #405: breaking change with cx_Oracle 7.0: DML RETURNING now returns a list +* #380: db_session should work with async functions +* #385: test fails with python3.6 +* #386: release unlocked lock error in SQLite +* #390: TypeError: writable buffers are not hashable +* #398: add auto coversion of numpy numeric types +* #404: GAE local run detection +* Fix Flask compatibility: add support of LocalProxy object +* db_session(sql_debug=True) should log SQL commands also during db_session.__exit__() +* Fix duplicated table join 
in FROM clause +* Fix accessing global variables from hybrid methods and properties +* Fix m2m collection loading bug +* Fix composite index bug: stackoverflow.com/questions/53147694 +* Fix MyEntity[obj.get_pk()] if pk is composite +* MySQL group_concat_max_len option set to max of 32bit platforms to avoid truncation +* Show all attribute options in show(Entity) call +* For nested db_session retry option should be ignored +* Fix py_json_unwrap +* Other minor fixes + + # Pony ORM Release 0.7.6 (2018-08-10) ## Bugfixes diff --git a/pony/__init__.py b/pony/__init__.py index 18d6d1be5..0decf5c25 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.7-dev' +__version__ = '0.7.7' uid = str(random.randint(1, 1000000)) From 213873768bce5db43bb5d85d8a89f3a88cb2c176 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 17 Jan 2019 11:31:46 +0300 Subject: [PATCH 433/547] Update Pony version: 0.7.7 -> 0.7.8-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 0decf5c25..7f7156ef6 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.7' +__version__ = '0.7.8-dev' uid = str(random.randint(1, 1000000)) From 5c5140736030f4042f3870ca22eea97ba52aeddc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Jan 2019 23:29:43 +0300 Subject: [PATCH 434/547] Add `stream` option to `query.show()` to make its output testable more easily --- pony/orm/core.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index e79639663..4b69066e8 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5868,8 +5868,8 @@ def _do_prefetch(query, query_result): objects_to_process = next_objects_to_process @cut_traceback - def show(query, width=None): - 
query._fetch().show(width) + def show(query, width=None, stream=None): + query._fetch().show(width, stream) @cut_traceback def get(query): objects = query[:2] @@ -6305,7 +6305,13 @@ def sort(self, *args, **kwargs): def shuffle(self): shuffle(self._get_items()) @cut_traceback - def show(self, width=None): + def show(self, width=None, stream=None): + if stream is None: + stream = sys.stdout + def writeln(s): + stream.write(s) + stream.write('\n') + if self._items is None: self._items = self._query._actual_fetch(self._limit, self._offset) @@ -6353,10 +6359,11 @@ def to_str(x): for col_num, max_len in remaining_columns.items(): width_dict[col_num] = base_len - print(strjoin('|', (strcut(colname, width_dict[i]) for i, colname in enumerate(col_names)))) - print(strjoin('+', ('-' * width_dict[i] for i in xrange(len(col_names))))) + writeln(strjoin('|', (strcut(colname, width_dict[i]) for i, colname in enumerate(col_names)))) + writeln(strjoin('+', ('-' * width_dict[i] for i in xrange(len(col_names))))) for row in rows: - print(strjoin('|', (strcut(item, width_dict[i]) for i, item in enumerate(row)))) + writeln(strjoin('|', (strcut(item, width_dict[i]) for i, item in enumerate(row)))) + stream.flush() def to_json(self, include=(), exclude=(), converter=None, with_schema=True, schema_hash=None): return self._query._database.to_json(self, include, exclude, converter, with_schema, schema_hash) From 3a8c3fef036bf4f0da18d2298e6c30b61572d53f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 18 Jan 2019 23:32:17 +0300 Subject: [PATCH 435/547] Fix a bug caused by incorrect deduplication of column values --- pony/orm/core.py | 5 ++- pony/orm/tests/test_deduplication.py | 49 ++++++++++++++++++++++++++++ pony/utils/utils.py | 7 ++++ 3 files changed, 58 insertions(+), 3 deletions(-) create mode 100644 pony/orm/tests/test_deduplication.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 4b69066e8..068776a0f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -32,7 
+32,7 @@ from pony import utils from pony.utils import localbase, decorator, cut_traceback, cut_traceback_depth, throw, reraise, truncate_repr, \ get_lambda_args, pickle_ast, unpickle_ast, deprecated, import_module, parse_expr, is_ident, tostring, strjoin, \ - between, concat, coalesce, HashableDict, deref_proxy + between, concat, coalesce, HashableDict, deref_proxy, deduplicate __all__ = [ 'pony', @@ -2206,8 +2206,7 @@ def parse_value(attr, row, offsets, dbvals_deduplication_cache): if len(offsets) > 1: throw(NotImplementedError) offset = offsets[0] dbval = attr.validate(row[offset], None, attr.entity, from_db=True) - try: dbval = dbvals_deduplication_cache.setdefault(dbval, dbval) - except: pass + dbval = deduplicate(dbval, dbvals_deduplication_cache) else: dbvals = [ row[offset] for offset in offsets ] if None in dbvals: diff --git a/pony/orm/tests/test_deduplication.py b/pony/orm/tests/test_deduplication.py new file mode 100644 index 000000000..6842deb26 --- /dev/null +++ b/pony/orm/tests/test_deduplication.py @@ -0,0 +1,49 @@ +from pony.py23compat import StringIO + +import unittest + +from pony import orm + + +db = orm.Database('sqlite', ':memory:') + +class A(db.Entity): + id = orm.PrimaryKey(int) + x = orm.Required(bool) + y = orm.Required(float) + +db.generate_mapping(create_tables=True) + +with orm.db_session: + a1 = A(id=1, x=False, y=3.0) + a2 = A(id=2, x=True, y=4.0) + a3 = A(id=3, x=False, y=1.0) + + +class TestDeduplication(unittest.TestCase): + @orm.db_session + def test_1(self): + a2 = A.get(id=2) + a1 = A.get(id=1) + self.assertIs(a1.id, 1) + + @orm.db_session + def test_2(self): + a3 = A.get(id=3) + a1 = A.get(id=1) + self.assertIs(a1.id, 1) + + @orm.db_session + def test_3(self): + q = A.select().order_by(-1) + stream = StringIO() + q.show(stream=stream) + s = stream.getvalue() + self.assertEqual(s, 'id|x |y \n' + '--+-----+---\n' + '3 |False|1.0\n' + '2 |True |4.0\n' + '1 |False|3.0\n') + + + diff --git a/pony/utils/utils.py 
b/pony/utils/utils.py index abf429eb6..9aecef744 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -601,3 +601,10 @@ def deref_proxy(value): value = value._get_object() return value + +def deduplicate(value, deduplication_cache): + t = type(value) + try: + return deduplication_cache.setdefault(t, t).setdefault(value, value) + except: + return value From 65a36c8da8bc06202d3bf2d326a6e9c869895184 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 19 Jan 2019 16:57:15 +0300 Subject: [PATCH 436/547] Fixes 414: prefetching Optional relationships fails on 0.7.7 --- pony/orm/core.py | 2 +- pony/orm/tests/test_prefetching.py | 43 +++++++++++++++++++++++++++--- 2 files changed, 41 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 068776a0f..69d7a50cf 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5843,7 +5843,7 @@ def _do_prefetch(query, query_result): collection_prefetch_dict[attr].add(obj) else: obj2 = attr.get(obj) - if obj2 not in all_objects: + if obj2 is not None and obj2 not in all_objects: all_objects.add(obj2) objects_to_prefetch.add(obj2) diff --git a/pony/orm/tests/test_prefetching.py b/pony/orm/tests/test_prefetching.py index d7afb54b4..68509a72f 100644 --- a/pony/orm/tests/test_prefetching.py +++ b/pony/orm/tests/test_prefetching.py @@ -14,6 +14,7 @@ class Student(db.Entity): dob = Optional(date) group = Required('Group') courses = Set('Course') + mentor = Optional('Teacher') biography = Optional(LongStr) class Group(db.Entity): @@ -25,6 +26,10 @@ class Course(db.Entity): name = Required(str, unique=True) students = Set(Student) +class Teacher(db.Entity): + name = Required(str) + students = Set(Student) + db.generate_mapping(create_tables=True) with db_session: @@ -33,10 +38,12 @@ class Course(db.Entity): c1 = Course(name='Math') c2 = Course(name='Physics') c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio') + t1 = 
Teacher(name='T1') + t2 = Teacher(name='T2') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio', mentor=t1) Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio') Student(id=3, name='S3', group=g1, gpa=4.7, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) - Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3]) + Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3], mentor=t2) Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3]) class TestPrefetching(unittest.TestCase): @@ -108,7 +115,8 @@ def test_12(self): with db_session: s1 = Student.select().prefetch(Student.biography).first() self.assertEqual(s1.biography, 'S1 bio') - self.assertEqual(db.last_sql, '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."biography" + self.assertEqual(db.last_sql, +'''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."mentor", "s"."biography" FROM "Student" "s" ORDER BY 1 LIMIT 1''') @@ -167,6 +175,35 @@ def test_17(self): query_count = db.local_stats[None].db_count self.assertEqual(query_count, 4) + def test_18(self): + db.merge_local_stats() + with db_session: + q = Group.select().prefetch(Group.students, Student.biography) + for g in q: # 2 queries + for s in g.students: + m = s.mentor # 0 queries + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 2) + + def test_19(self): + db.merge_local_stats() + with db_session: + q = Group.select().prefetch(Group.students, Student.biography, Student.mentor) + mentors = set() + for g in q: # 3 queries + for s in g.students: + m = s.mentor # 0 queries + if m is not None: + mentors.add(m) + b = s.biography # 0 queries + query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 3) + + for m in mentors: + n = m.name # 0 queries + 
query_count = db.local_stats[None].db_count + self.assertEqual(query_count, 3) if __name__ == '__main__': From c3d3e9d4ed9d4633e8aebc71265c3b33694af3f9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 19 Jan 2019 17:18:04 +0300 Subject: [PATCH 437/547] Update changelog and pony version: 0.7.8-dev -> 0.7.8 --- CHANGELOG.md | 8 ++++++++ pony/__init__.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index acb476677..487eaac27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# PonyORM release 0.7.8 (2019-01-19) + +## Bugfixes + +* #414: prefetching Optional relationships fails on 0.7.7 +* Fix a bug caused by incorrect deduplication of column values + + # PonyORM release 0.7.7 (2019-01-17) ## Major features diff --git a/pony/__init__.py b/pony/__init__.py index 7f7156ef6..b3dc363fd 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.8-dev' +__version__ = '0.7.8' uid = str(random.randint(1, 1000000)) From 35fe54e0783644b45045bdd6e42bb1a5d41b008a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Jan 2019 01:00:52 +0300 Subject: [PATCH 438/547] Update Pony version: 0.7.8 -> 0.7.9-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index b3dc363fd..d77c39169 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.8' +__version__ = '0.7.9-dev' uid = str(random.randint(1, 1000000)) From 2b17c30dc17f7cb4cdfbca693cd1075d06bfb38c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Jan 2019 00:59:19 +0300 Subject: [PATCH 439/547] Fix empty array param handling --- pony/orm/sqltranslation.py | 2 +- pony/orm/tests/test_array.py | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 1 deletion(-) diff --git 
a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a0ebd2656..aa2efb44d 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -142,7 +142,7 @@ def dispatch_external(translator, node): elif tt is tuple: params = [] is_array = False - if translator.database.provider.array_converter_cls is not None: + if t and translator.database.provider.array_converter_cls is not None: types = set(t) if len(types) == 1 and unicode in types: item_type = unicode diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index 728b9ceb5..61d6dbb63 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -223,3 +223,28 @@ def test_35(self): self.assertTrue(['foo', 'bar'] in foo.array3) self.assertTrue(['bar', 'foo'] in foo.array3) self.assertTrue(['baz', 'bar'] not in foo.array3) + + @db_session(sql_debug=True) + def test_36(self): + items = [] + result = select(foo for foo in Foo if foo.id in items)[:] + self.assertEqual(result, []) + + @db_session(sql_debug=True) + def test_37(self): + items = [1] + result = select(foo.id for foo in Foo if foo.id in items)[:] + self.assertEqual(result, [1]) + + @db_session(sql_debug=True) + def test_38(self): + f1 = Foo[1] + items = [f1] + result = select(foo for foo in Foo if foo in items)[:] + self.assertEqual(result, [f1]) + + @db_session(sql_debug=True) + def test_39(self): + items = [] + result = select(foo for foo in Foo if foo in items)[:] + self.assertEqual(result, []) From 5ed45d21a437edfb88292050c0c65a45b75f21dc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Jan 2019 10:09:38 +0300 Subject: [PATCH 440/547] Fix reading NULL from optional nullable array column --- pony/orm/dbapiprovider.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 378f3bff4..50caac12d 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -846,7 +846,7 @@ def validate(converter, val, 
obj=None): return TrackedArray(obj, converter.attr, items) def dbval2val(converter, dbval, obj=None): - if obj is None: + if obj is None or dbval is None: return dbval return TrackedArray(obj, converter.attr, dbval) From 61841d360396e3445fa30ab0b63d59c7c31a15a9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Jan 2019 10:11:10 +0300 Subject: [PATCH 441/547] Fix handling of empty arrays in queries --- pony/orm/sqltranslation.py | 32 ++++++++++++++++++++++---------- pony/orm/tests/test_array.py | 31 +++++++++++++++++++------------ 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index aa2efb44d..6b01fd1c8 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -157,16 +157,15 @@ def dispatch_external(translator, node): else: is_array = True + for i, item_type in enumerate(t): + if item_type is NoneType: + throw(TypeError, 'Expression `%s` should not contain None values' % node.src) + param = ParamMonad.new(item_type, (varkey, i, None)) + params.append(param) + monad = ListMonad(params) if is_array: array_type = array_types.get(item_type, None) - monad = ArrayParamMonad(array_type, (varkey, None, None)) - else: - for i, item_type in enumerate(t): - if item_type is NoneType: - throw(TypeError, 'Expression `%s` should not contain None values' % node.src) - param = ParamMonad.new(item_type, (varkey, i, None)) - params.append(param) - monad = ListMonad(params) + monad = ArrayParamMonad(array_type, (varkey, None, None), list_monad=monad) elif isinstance(t, RawSQLType): monad = RawSQLMonad(t, varkey) else: @@ -2049,6 +2048,11 @@ def contains(monad, key, not_in=False): sql = 'ARRAY_CONTAINS', key.getsql()[0], not_in, monad.getsql()[0] return BoolExprMonad(sql) if isinstance(key, ListMonad): + if not key.items: + if not_in: + return BoolExprMonad(['EQ', ['VALUE', 0], ['VALUE', 1]], nullable=False) + else: + return BoolExprMonad(['EQ', ['VALUE', 1], ['VALUE', 1]], 
nullable=False) sql = [ 'MAKE_ARRAY' ] sql.extend(item.getsql()[0] for item in key.items) sql = 'ARRAY_SUBSET', sql, not_in, monad.getsql()[0] @@ -2231,7 +2235,7 @@ def new(t, paramkey): result = cls(t, paramkey) result.aggregated = False return result - def __new__(cls, *args): + def __new__(cls, *args, **kwargs): if cls is ParamMonad: assert False, 'Abstract class' # pragma: no cover return Monad.__new__(cls) def __init__(monad, t, paramkey): @@ -2268,7 +2272,15 @@ class TimedeltaParamMonad(TimedeltaMixin, ParamMonad): pass class DatetimeParamMonad(DatetimeMixin, ParamMonad): pass class BufferParamMonad(BufferMixin, ParamMonad): pass class UuidParamMonad(UuidMixin, ParamMonad): pass -class ArrayParamMonad(ArrayMixin, ParamMonad): pass + +class ArrayParamMonad(ArrayMixin, ParamMonad): + def __init__(monad, t, paramkey, list_monad=None): + ParamMonad.__init__(monad, t, paramkey) + monad.list_monad = list_monad + def contains(monad, key, not_in=False): + if key.type is monad.type.item_type: + return monad.list_monad.contains(key, not_in) + return ArrayMixin.contains(monad, key, not_in) class JsonParamMonad(JsonMixin, ParamMonad): def getsql(monad, sqlquery=None): diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index 61d6dbb63..e446b2aac 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -35,6 +35,11 @@ def test_2(self): foo = select(f for f in Foo if [10, 20, 50] in f.array1)[:] self.assertEqual([Foo[1]], foo) + @db_session + def test_2a(self): + foo = select(f for f in Foo if [] in f.array1)[:] + self.assertEqual([Foo[1]], foo) + @db_session def test_3(self): x = [10, 20, 50] @@ -218,33 +223,35 @@ def test_35(self): self.assertTrue([10, 20] in foo.array1) self.assertTrue([20, 10] in foo.array1) self.assertTrue([10, 1000] not in foo.array1) + self.assertTrue([] in foo.array1) self.assertTrue('bar' in foo.array3) self.assertTrue('baz' not in foo.array3) self.assertTrue(['foo', 'bar'] in foo.array3) 
self.assertTrue(['bar', 'foo'] in foo.array3) self.assertTrue(['baz', 'bar'] not in foo.array3) + self.assertTrue([] in foo.array3) - @db_session(sql_debug=True) + @db_session def test_36(self): items = [] - result = select(foo for foo in Foo if foo.id in items)[:] + result = select(foo for foo in Foo if foo in items)[:] self.assertEqual(result, []) - @db_session(sql_debug=True) + @db_session def test_37(self): - items = [1] - result = select(foo.id for foo in Foo if foo.id in items)[:] - self.assertEqual(result, [1]) - - @db_session(sql_debug=True) - def test_38(self): f1 = Foo[1] items = [f1] result = select(foo for foo in Foo if foo in items)[:] self.assertEqual(result, [f1]) - @db_session(sql_debug=True) - def test_39(self): + @db_session + def test_38(self): items = [] - result = select(foo for foo in Foo if foo in items)[:] + result = select(foo for foo in Foo if foo.id in items)[:] self.assertEqual(result, []) + + @db_session + def test_39(self): + items = [1] + result = select(foo.id for foo in Foo if foo.id in items)[:] + self.assertEqual(result, [1]) From bf110778ea17809e2a6acb69bf432970f2bc6320 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Jan 2019 10:19:05 +0300 Subject: [PATCH 442/547] Update changelog and change Pony version: 0.7.9-dev -> 0.7.9 --- CHANGELOG.md | 8 ++++++++ pony/__init__.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 487eaac27..713dac2c0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,11 @@ +# PonyORM release 0.7.9 (2019-01-21) + +## Bugfixes + +* Fix handling of empty arrays and empty lists in queries +* Fix reading optional nullable array columns from database + + # PonyORM release 0.7.8 (2019-01-19) ## Bugfixes diff --git a/pony/__init__.py b/pony/__init__.py index d77c39169..a21e8e6ce 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.9-dev' +__version__ 
= '0.7.9' uid = str(random.randint(1, 1000000)) From 9031bf42b6f6c50d87ece32180ef94c2525369c6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Jan 2019 10:21:02 +0300 Subject: [PATCH 443/547] Update Pony version: 0.7.9 -> 0.7.10-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index a21e8e6ce..808a03d62 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.9' +__version__ = '0.7.10-dev' uid = str(random.randint(1, 1000000)) From 3b895fd52e02520afafd04c65a00ef0892e48a85 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 30 Jan 2019 11:44:29 +0300 Subject: [PATCH 444/547] Fixes #415: typo --- pony/orm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 69d7a50cf..791036086 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -834,7 +834,7 @@ def get_connection(database): def disconnect(database): provider = database.provider if provider is None: return - if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_sesison') + if local.db_context_counter: throw(TransactionError, 'disconnect() cannot be called inside of db_session') cache = local.db2cache.get(database) if cache is not None: cache.rollback() provider.disconnect() From 50cf5b7eb5bdfc3a4d7324258c54fd87112b4027 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 10 Feb 2019 15:48:56 +0300 Subject: [PATCH 445/547] Message improved: NotImplementedError on specifying table name for inherited entity --- pony/orm/core.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 791036086..fb6256980 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -974,7 +974,8 @@ def get_columns(table, column_names): is_subclass = entity._root_ is not entity if 
is_subclass: - if table_name is not None: throw(NotImplementedError) + if table_name is not None: throw(NotImplementedError, + 'Cannot specify table name for entity %r which is subclass of %r' % (entity.__name__, entity._root_.__name__)) table_name = entity._root_._table_ entity._table_ = table_name elif table_name is None: From 0b9667b18aee1361a364d6df99fcf8f1ef6f303b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 6 Mar 2019 22:48:22 +0300 Subject: [PATCH 446/547] Add second lock to prevent thread starvation on some operating systems --- pony/orm/dbproviders/sqlite.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 483959ff8..04411bc3c 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -296,6 +296,7 @@ class SQLiteProvider(DBAPIProvider): def __init__(provider, *args, **kwargs): DBAPIProvider.__init__(provider, *args, **kwargs) + provider.pre_transaction_lock = Lock() provider.transaction_lock = Lock() @wrap_dbapi_exceptions @@ -308,11 +309,21 @@ def restore_exception(provider): try: reraise(*provider.local_exceptions.exc_info) finally: provider.local_exceptions.exc_info = None + def acquire_lock(provider): + provider.pre_transaction_lock.acquire() + try: + provider.transaction_lock.acquire() + finally: + provider.pre_transaction_lock.release() + + def release_lock(provider): + provider.transaction_lock.release() + @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction if cache.immediate: - provider.transaction_lock.acquire() + provider.acquire_lock() try: cursor = connection.cursor() @@ -336,7 +347,7 @@ def set_transaction_mode(provider, connection, cache): elif core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') finally: if cache.immediate and not cache.in_transaction: - provider.transaction_lock.release() + provider.release_lock() def 
commit(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction @@ -345,7 +356,7 @@ def commit(provider, connection, cache=None): finally: if in_transaction: cache.in_transaction = False - provider.transaction_lock.release() + provider.release_lock() def rollback(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction @@ -354,7 +365,7 @@ def rollback(provider, connection, cache=None): finally: if in_transaction: cache.in_transaction = False - provider.transaction_lock.release() + provider.release_lock() def drop(provider, connection, cache=None): in_transaction = cache is not None and cache.in_transaction @@ -363,7 +374,7 @@ def drop(provider, connection, cache=None): finally: if in_transaction: cache.in_transaction = False - provider.transaction_lock.release() + provider.release_lock() @wrap_dbapi_exceptions def release(provider, connection, cache=None): From 5618ccee780bb6319847df8acffc217989e26392 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Fri, 15 Mar 2019 21:32:21 +0300 Subject: [PATCH 447/547] Fixes #432: flask can trigger teardown_request without real request. 
--- pony/flask/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pony/flask/__init__.py b/pony/flask/__init__.py index d63fa6856..75e877cb5 100644 --- a/pony/flask/__init__.py +++ b/pony/flask/__init__.py @@ -8,9 +8,8 @@ def _enter_session(): def _exit_session(exception): session = getattr(request, 'pony_session', None) - if session is None: - raise RuntimeError('Request object lost db_session') - session.__exit__(exc=exception) + if session is not None: + session.__exit__(exc=exception) class Pony(object): def __init__(self, app=None): From 8587f885c8058c291272cf55ee13752eaa85f971 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 24 Mar 2019 18:38:20 +0300 Subject: [PATCH 448/547] Handle case when someone calls db.bind(kwargs) instead of db.bind(**kwargs) --- pony/orm/core.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index fb6256980..958392b21 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -767,13 +767,17 @@ def _bind(self, *args, **kwargs): # argument 'self' cannot be named 'database', because 'database' can be in kwargs if self.provider is not None: throw(BindingError, 'Database object was already bound to %s provider' % self.provider.dialect) + if len(args) == 1 and not kwargs and hasattr(args[0], 'keys'): + args, kwargs = (), args[0] + provider = None if args: provider, args = args[0], args[1:] elif 'provider' not in kwargs: throw(TypeError, 'Database provider is not specified') else: provider = kwargs.pop('provider') if isinstance(provider, type) and issubclass(provider, DBAPIProvider): provider_cls = provider else: - if not isinstance(provider, basestring): throw(TypeError) + if not isinstance(provider, basestring): + throw(TypeError, 'Provider name should be string. Got: %r' % type(provider).__name__) if provider == 'pygresql': throw(TypeError, 'Pony no longer supports PyGreSQL module. 
Please use psycopg2 instead.') self.provider_name = provider From 9781e061cf5458d89714828b94e6a7f018f09c92 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 24 Mar 2019 18:40:56 +0300 Subject: [PATCH 449/547] Decompiler fixes --- pony/orm/decompiling.py | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index c5513d593..f37fa9aeb 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -154,8 +154,15 @@ def store(decompiler, node): def BINARY_SUBSCR(decompiler): oper2 = decompiler.stack.pop() oper1 = decompiler.stack.pop() - if isinstance(oper2, ast.Tuple): return ast.Subscript(oper1, 'OP_APPLY', list(oper2.nodes)) - else: return ast.Subscript(oper1, 'OP_APPLY', [ oper2 ]) + if isinstance(oper2, ast.Sliceobj) and len(oper2.nodes) == 2: + a, b = oper2.nodes + a = None if isinstance(a, ast.Const) and a.value == None else a + b = None if isinstance(b, ast.Const) and b.value == None else b + return ast.Slice(oper1, 'OP_APPLY', a, b) + elif isinstance(oper2, ast.Tuple): + return ast.Subscript(oper1, 'OP_APPLY', list(oper2.nodes)) + else: + return ast.Subscript(oper1, 'OP_APPLY', [ oper2 ]) def BUILD_CONST_KEY_MAP(decompiler, length): keys = decompiler.stack.pop() @@ -174,7 +181,7 @@ def BUILD_MAP(decompiler, length): data = decompiler.pop_items(2 * length) # [key1, value1, key2, value2, ...] it = iter(data) pairs = list(izip(it, it)) # [(key1, value1), (key2, value2), ...] 
- return ast.Dict(pairs) + return ast.Dict(tuple(pairs)) def BUILD_SET(decompiler, size): return ast.Set(decompiler.pop_items(size)) @@ -539,8 +546,13 @@ def YIELD_VALUE(decompiler): (a for b in T if f == 5 and r or t) (a for b in T if f and r and t) - (a for b in T if f == 5 and +r or not t) - (a for b in T if -t and ~r or `f`) + # (a for b in T if f == 5 and +r or not t) + # (a for b in T if -t and ~r or `f`) + + # (a for b in T if not x and y) + # (a for b in T if not x and y and z) + # (a for b in T if not x and y or z) + # (a for b in T if x and not y and z) (a**2 for b in T if t * r > y / 3) (a + 2 for b in T if t + r > y // 3) @@ -574,10 +586,12 @@ def YIELD_VALUE(decompiler): (s for s in T if s.a > 20 and (s.x.y == 123 or 'ABC' in s.p.q.r)) (a for b in T1 if c > d for e in T2 if f < g) - (func1(a, a.attr, keyarg=123) for s in T) - (func1(a, a.attr, keyarg=123, *e) for s in T) - (func1(a, b, a.attr1, a.b.c, keyarg1=123, keyarg2='mx', *e, **f) for s in T) - (func(a, a.attr, keyarg=123) for a in T if a.method(x, *y, **z) == 4) + (func1(a, a.attr, x=123) for s in T) + # (func1(a, a.attr, *args) for s in T) + # (func1(a, a.attr, x=123, **kwargs) for s in T) + (func1(a, b, a.attr1, a.b.c, x=123, y='foo') for s in T) + # (func1(a, b, a.attr1, a.b.c, x=123, y='foo', **kwargs) for s in T) + # (func(a, a.attr, keyarg=123) for a in T if a.method(x, *args, **kwargs) == 4) ((x or y) and (p or q) for a in T if (a or b) and (c or d)) (x.y for x in T if (a and (b or (c and d))) or X) From fad8c6863ea439ea043fe597e40b964bbb8792da Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 24 Mar 2019 20:18:43 +0300 Subject: [PATCH 450/547] Fix decompliation of Python 3.7 bytecode with conditions like "if not a and b" --- pony/orm/decompiling.py | 180 +++++++++++++++++++++++++++------------- pony/py23compat.py | 1 + 2 files changed, 123 insertions(+), 58 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index f37fa9aeb..b4a43d19d 100644 --- 
a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -1,9 +1,10 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, izip, xrange +from pony.py23compat import PY2, izip, xrange, PY37 import sys, types, inspect from opcode import opname as opnames, HAVE_ARGUMENT, EXTENDED_ARG, cmp_op from opcode import hasconst, hasname, hasjrel, haslocal, hascompare, hasfree +from collections import defaultdict from pony.thirdparty.compiler import ast, parse @@ -47,8 +48,6 @@ def simplify(clause): class InvalidQuery(Exception): pass -class AstGenerated(Exception): pass - def binop(node_type, args_holder=tuple): def method(decompiler): oper2 = decompiler.stack.pop() @@ -65,61 +64,99 @@ def __init__(decompiler, code, start=0, end=None): if end is None: end = len(code.co_code) decompiler.end = end decompiler.stack = [] + decompiler.jump_map = defaultdict(list) decompiler.targets = {} decompiler.ast = None decompiler.names = set() decompiler.assnames = set() + decompiler.conditions_end = 0 + decompiler.instructions = [] + decompiler.instructions_map = {} + decompiler.or_jumps = set() + decompiler.get_instructions() + decompiler.analyze_jumps() decompiler.decompile() decompiler.ast = decompiler.stack.pop() decompiler.external_names = decompiler.names - decompiler.assnames assert not decompiler.stack, decompiler.stack - def decompile(decompiler): + def get_instructions(decompiler): PY36 = sys.version_info >= (3, 6) code = decompiler.code co_code = code.co_code free = code.co_cellvars + code.co_freevars - try: - while decompiler.pos < decompiler.end: - i = decompiler.pos - if i in decompiler.targets: decompiler.process_target(i) - op = ord(code.co_code[i]) - if PY36: - extended_arg = 0 + while decompiler.pos < decompiler.end: + i = decompiler.pos + op = ord(code.co_code[i]) + if PY36: + extended_arg = 0 + oparg = ord(code.co_code[i+1]) + while op == EXTENDED_ARG: + extended_arg = (extended_arg | oparg) << 8 + i += 2 + op = 
ord(code.co_code[i]) oparg = ord(code.co_code[i+1]) - while op == EXTENDED_ARG: - extended_arg = (extended_arg | oparg) << 8 - i += 2 - op = ord(code.co_code[i]) - oparg = ord(code.co_code[i+1]) - oparg = None if op < HAVE_ARGUMENT else oparg | extended_arg + oparg = None if op < HAVE_ARGUMENT else oparg | extended_arg + i += 2 + else: + i += 1 + if op >= HAVE_ARGUMENT: + oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 i += 2 - else: - i += 1 - if op >= HAVE_ARGUMENT: - oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 + if op == EXTENDED_ARG: + op = ord(code.co_code[i]) + i += 1 + oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 + oparg * 65536 i += 2 - if op == EXTENDED_ARG: - op = ord(code.co_code[i]) - i += 1 - oparg = ord(co_code[i]) + ord(co_code[i + 1]) * 256 + oparg * 65536 - i += 2 - if op >= HAVE_ARGUMENT: - if op in hasconst: arg = [code.co_consts[oparg]] - elif op in hasname: arg = [code.co_names[oparg]] - elif op in hasjrel: arg = [i + oparg] - elif op in haslocal: arg = [code.co_varnames[oparg]] - elif op in hascompare: arg = [cmp_op[oparg]] - elif op in hasfree: arg = [free[oparg]] - else: arg = [oparg] - else: arg = [] - opname = opnames[op].replace('+', '_') - # print(opname, arg, decompiler.stack) - method = getattr(decompiler, opname, None) - if method is None: throw(NotImplementedError('Unsupported operation: %s' % opname)) - decompiler.pos = i - x = method(*arg) - if x is not None: decompiler.stack.append(x) - except AstGenerated: pass + if op >= HAVE_ARGUMENT: + if op in hasconst: arg = [code.co_consts[oparg]] + elif op in hasname: arg = [code.co_names[oparg]] + elif op in hasjrel: arg = [i + oparg] + elif op in haslocal: arg = [code.co_varnames[oparg]] + elif op in hascompare: arg = [cmp_op[oparg]] + elif op in hasfree: arg = [free[oparg]] + else: arg = [oparg] + else: arg = [] + opname = opnames[op].replace('+', '_') + if 'JUMP' in opname: + endpos = arg[0] + if endpos < decompiler.pos: + decompiler.conditions_end = i + 
decompiler.jump_map[endpos].append(decompiler.pos) + decompiler.instructions_map[decompiler.pos] = len(decompiler.instructions) + decompiler.instructions.append((decompiler.pos, i, opname, arg)) + if opname == 'YIELD_VALUE': + return + decompiler.pos = i + def analyze_jumps(decompiler): + i = decompiler.instructions_map[decompiler.conditions_end] + while i > 0: + pos, next_pos, opname, arg = decompiler.instructions[i] + if pos in decompiler.jump_map: + for jump_start_pos in decompiler.jump_map[pos]: + if jump_start_pos > pos: + continue + for or_jump_start_pos in decompiler.or_jumps: + if pos > or_jump_start_pos > jump_start_pos: + break # And jump + else: + decompiler.or_jumps.add(jump_start_pos) + i -= 1 + def decompile(decompiler): + # print(decompiler.conditions_end) + for pos, next_pos, opname, arg in decompiler.instructions: + # print(i, opname, *arg) + if pos in decompiler.targets: + decompiler.process_target(pos) + method = getattr(decompiler, opname, None) + if method is None: + throw(NotImplementedError('Unsupported operation: %s' % opname)) + decompiler.pos = pos + decompiler.next_pos = next_pos + x = method(*arg) + if x is not None: + decompiler.stack.append(x) + # print(decompiler.stack) + def pop_items(decompiler, size): if not size: return () result = decompiler.stack[-size:] @@ -285,18 +322,47 @@ def GET_ITER(decompiler): pass def JUMP_IF_FALSE(decompiler, endpos): - return decompiler.conditional_jump(endpos, ast.And) + return decompiler.conditional_jump(endpos, False) JUMP_IF_FALSE_OR_POP = JUMP_IF_FALSE def JUMP_IF_TRUE(decompiler, endpos): - return decompiler.conditional_jump(endpos, ast.Or) + return decompiler.conditional_jump(endpos, True) JUMP_IF_TRUE_OR_POP = JUMP_IF_TRUE - def conditional_jump(decompiler, endpos, clausetype): - i = decompiler.pos # next instruction - if i in decompiler.targets: decompiler.process_target(i) + def conditional_jump(decompiler, endpos, if_true): + if PY37: return decompiler.conditional_jump_new(endpos, if_true) 
+ return decompiler.conditional_jump_old(endpos, if_true) + + def conditional_jump_old(decompiler, endpos, if_true): + i = decompiler.next_pos + if i in decompiler.targets: + decompiler.process_target(i) + expr = decompiler.stack.pop() + clausetype = ast.Or if if_true else ast.And + clause = clausetype([expr]) + clause.endpos = endpos + decompiler.targets.setdefault(endpos, clause) + return clause + + def conditional_jump_new(decompiler, endpos, if_true): + expr = decompiler.stack.pop() + if decompiler.pos >= decompiler.conditions_end: + clausetype = ast.Or if if_true else ast.And + elif decompiler.pos in decompiler.or_jumps: + clausetype = ast.Or + if not if_true: + expr = ast.Not(expr) + else: + clausetype = ast.And + if if_true: + expr = ast.Not(expr) + decompiler.stack.append(expr) + + if decompiler.next_pos in decompiler.targets: + decompiler.process_target(decompiler.next_pos) + expr = decompiler.stack.pop() clause = clausetype([ expr ]) clause.endpos = endpos @@ -331,7 +397,7 @@ def process_target(decompiler, pos, partial=False): decompiler.stack.append(top) def JUMP_FORWARD(decompiler, endpos): - i = decompiler.pos # next instruction + i = decompiler.next_pos # next instruction decompiler.process_target(i, True) then = decompiler.stack.pop() decompiler.process_target(i, False) @@ -425,10 +491,9 @@ def POP_TOP(decompiler): pass def RETURN_VALUE(decompiler): - if decompiler.pos != decompiler.end: throw(NotImplementedError) + if decompiler.next_pos != decompiler.end: throw(NotImplementedError) expr = decompiler.stack.pop() - decompiler.stack.append(simplify(expr)) - raise AstGenerated() + return simplify(expr) def ROT_TWO(decompiler): tos = decompiler.stack.pop() @@ -530,8 +595,7 @@ def YIELD_VALUE(decompiler): fors.append(top) else: fors.append(top) fors.reverse() - decompiler.stack.append(ast.GenExpr(ast.GenExprInner(simplify(expr), fors))) - raise AstGenerated() + return ast.GenExpr(ast.GenExprInner(simplify(expr), fors)) test_lines = """ (a and b if c and 
d else e and f for i in T if (A and B if C and D else E and F)) @@ -549,10 +613,10 @@ def YIELD_VALUE(decompiler): # (a for b in T if f == 5 and +r or not t) # (a for b in T if -t and ~r or `f`) - # (a for b in T if not x and y) - # (a for b in T if not x and y and z) - # (a for b in T if not x and y or z) - # (a for b in T if x and not y and z) + (a for b in T if x and not y and z) + (a for b in T if not x and y) + (a for b in T if not x and y and z) + (a for b in T if not x and y or z) #FIXME! (a**2 for b in T if t * r > y / 3) (a + 2 for b in T if t + r > y // 3) diff --git a/pony/py23compat.py b/pony/py23compat.py index f23b99f37..7fe218ee9 100644 --- a/pony/py23compat.py +++ b/pony/py23compat.py @@ -3,6 +3,7 @@ PY2 = sys.version_info[0] == 2 PYPY = platform.python_implementation() == 'PyPy' PYPY2 = PYPY and PY2 +PY37 = sys.version_info[:2] >= (3, 7) if PY2: from future_builtins import zip as izip, map as imap From f044240c75c8b9585af3a293cf50f9843f64a86a Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 14 Apr 2019 14:38:47 +0300 Subject: [PATCH 451/547] Decompiler fix: decompiling lambda leads to pop from empty stack in process targets. 
--- pony/orm/decompiling.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index b4a43d19d..05f1a4ec8 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -142,9 +142,7 @@ def analyze_jumps(decompiler): decompiler.or_jumps.add(jump_start_pos) i -= 1 def decompile(decompiler): - # print(decompiler.conditions_end) for pos, next_pos, opname, arg in decompiler.instructions: - # print(i, opname, *arg) if pos in decompiler.targets: decompiler.process_target(pos) method = getattr(decompiler, opname, None) @@ -155,7 +153,6 @@ def decompile(decompiler): x = method(*arg) if x is not None: decompiler.stack.append(x) - # print(decompiler.stack) def pop_items(decompiler, size): if not size: return () @@ -378,7 +375,7 @@ def process_target(decompiler, pos, partial=False): top = simplify(top) if top is limit: break if isinstance(top, ast.GenExprFor): break - + if not decompiler.stack: break top2 = decompiler.stack[-1] if isinstance(top2, ast.GenExprFor): break if partial and hasattr(top2, 'endpos') and top2.endpos == pos: break From 314763677f89dd6ce5cec8b33184fa9db260c50e Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 14 Apr 2019 14:38:59 +0300 Subject: [PATCH 452/547] Decompiler tests added --- pony/orm/tests/test_decompiler.py | 98 +++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) create mode 100644 pony/orm/tests/test_decompiler.py diff --git a/pony/orm/tests/test_decompiler.py b/pony/orm/tests/test_decompiler.py new file mode 100644 index 000000000..ac31dc49f --- /dev/null +++ b/pony/orm/tests/test_decompiler.py @@ -0,0 +1,98 @@ +import unittest + +from pony.thirdparty.compiler.transformer import parse +from pony.orm.decompiling import Decompiler +from pony.orm.asttranslation import ast2src + + +def generate_gens(): + patterns = [ + '(x * y) * [z * j)', + '([x * y) * z) * j', + '(x * [y * z)) * j', + 'x * ([y * z) * j)', + 'x * (y * [z * j))' + ] + + ops = 
('and', 'or') + nots = (True, False) + + result = [] + + for pat in patterns: + cur = pat + for op1 in ops: + for op2 in ops: + for op3 in ops: + res = cur.replace('*', op1, 1) + res = res.replace('*', op2, 1) + res = res.replace('*', op3, 1) + result.append(res) + + final = [] + + for res in result: + for par1 in nots: + for par2 in nots: + for a in nots: + for b in nots: + for c in nots: + for d in nots: + cur = res.replace('(', 'not(') if not par1 else res + if not par2: + cur = cur.replace('[', 'not(') + else: + cur = cur.replace('[', '(') + if not a: cur = cur.replace('x', 'not x') + if not b: cur = cur.replace('y', 'not y') + if not c: cur = cur.replace('z', 'not z') + if not d: cur = cur.replace('j', 'not j') + final.append(cur) + + return final + +def create_test(gen): + def wrapped_test(self): + def get_condition_values(cond): + result = [] + vals = (True, False) + for x in vals: + for y in vals: + for z in vals: + for j in vals: + result.append(eval(cond, {'x': x, 'y': y, 'z': z, 'j': j})) + return result + src1 = '(a for a in [] if %s)' % gen + src2 = 'lambda x, y, z, j: (%s)' % gen + src3 = '(m for m in [] if %s for n in [] if %s)' % (gen, gen) + + code1 = compile(src1, '', 'eval').co_consts[0] + ast1 = Decompiler(code1).ast + src1 = ast2src(ast1).replace('.0', '[]') + src1 = src1[src1.find('if')+2:-1] + + code2 = compile(src2, '', 'eval').co_consts[0] + ast2 = Decompiler(code2).ast + src2 = ast2src(ast2).replace('.0', '[]') + src2 = src2[src2.find(':')+1:] + + code3 = compile(src3, '', 'eval').co_consts[0] + ast3 = Decompiler(code3).ast + src3 = ast2src(ast3).replace('.0', '[]') + src3 = src3[src3.find('if')+2: src3.rfind('for')-1] + + self.assertEqual(get_condition_values(gen), get_condition_values(src1)) + self.assertEqual(get_condition_values(gen), get_condition_values(src2)) + self.assertEqual(get_condition_values(gen), get_condition_values(src3)) + + return wrapped_test + + +class TestDecompiler(unittest.TestCase): + pass + + +for i, gen in 
enumerate(generate_gens()): + test_method = create_test(gen) + test_method.__name__ = 'test_decompiler_%d' % i + setattr(TestDecompiler, test_method.__name__, test_method) From c819c2f7f16d0e3888d00dab6bf4fea81d27a9f3 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sun, 14 Apr 2019 18:47:33 +0300 Subject: [PATCH 453/547] PyPy decompiling fix --- pony/orm/decompiling.py | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 05f1a4ec8..4759dfb66 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -1,5 +1,5 @@ from __future__ import absolute_import, print_function, division -from pony.py23compat import PY2, izip, xrange, PY37 +from pony.py23compat import PY2, izip, xrange, PY37, PYPY import sys, types, inspect from opcode import opname as opnames, HAVE_ARGUMENT, EXTENDED_ARG, cmp_op @@ -81,9 +81,11 @@ def __init__(decompiler, code, start=0, end=None): assert not decompiler.stack, decompiler.stack def get_instructions(decompiler): PY36 = sys.version_info >= (3, 6) + before_yield = True code = decompiler.code co_code = code.co_code free = code.co_cellvars + code.co_freevars + decompiler.abs_jump_to_top = decompiler.for_iter_pos = -1 while decompiler.pos < decompiler.end: i = decompiler.pos op = ord(code.co_code[i]) @@ -117,17 +119,33 @@ def get_instructions(decompiler): else: arg = [oparg] else: arg = [] opname = opnames[op].replace('+', '_') - if 'JUMP' in opname: - endpos = arg[0] - if endpos < decompiler.pos: - decompiler.conditions_end = i - decompiler.jump_map[endpos].append(decompiler.pos) - decompiler.instructions_map[decompiler.pos] = len(decompiler.instructions) - decompiler.instructions.append((decompiler.pos, i, opname, arg)) + if opname == 'FOR_ITER': + decompiler.for_iter_pos = decompiler.pos + if opname == 'JUMP_ABSOLUTE' and arg[0] == decompiler.for_iter_pos: + decompiler.abs_jump_to_top = decompiler.pos + + if 
before_yield: + if 'JUMP' in opname: + endpos = arg[0] + if endpos < decompiler.pos: + decompiler.conditions_end = i + decompiler.jump_map[endpos].append(decompiler.pos) + decompiler.instructions_map[decompiler.pos] = len(decompiler.instructions) + decompiler.instructions.append((decompiler.pos, i, opname, arg)) if opname == 'YIELD_VALUE': - return + before_yield = False decompiler.pos = i def analyze_jumps(decompiler): + if PYPY: + targets = decompiler.jump_map.pop(decompiler.abs_jump_to_top, []) + decompiler.jump_map[decompiler.for_iter_pos] = targets + for i, (x, y, opname, arg) in enumerate(decompiler.instructions): + if 'JUMP' in opname: + target = arg[0] + if target == decompiler.abs_jump_to_top: + decompiler.instructions[i] = (x, y, opname, [decompiler.for_iter_pos]) + decompiler.conditions_end = y + i = decompiler.instructions_map[decompiler.conditions_end] while i > 0: pos, next_pos, opname, arg = decompiler.instructions[i] @@ -329,7 +347,8 @@ def JUMP_IF_TRUE(decompiler, endpos): JUMP_IF_TRUE_OR_POP = JUMP_IF_TRUE def conditional_jump(decompiler, endpos, if_true): - if PY37: return decompiler.conditional_jump_new(endpos, if_true) + if PY37 or PYPY: + return decompiler.conditional_jump_new(endpos, if_true) return decompiler.conditional_jump_old(endpos, if_true) def conditional_jump_old(decompiler, endpos, if_true): From 333e39fd94f0ef57bfef07aff854b231bef16894 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 14 Apr 2019 19:24:43 +0300 Subject: [PATCH 454/547] Fix CALL_METHOD in PyPy --- pony/orm/decompiling.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 4759dfb66..080420a3d 100644 --- a/pony/orm/decompiling.py +++ b/pony/orm/decompiling.py @@ -304,6 +304,15 @@ def CALL_FUNCTION_EX(decompiler, argc): def CALL_METHOD(decompiler, argc): pop = decompiler.stack.pop args = [] + if argc >= 256: + kwargc = argc // 256 + argc = argc % 256 + for i in range(kwargc): + v = pop() + k = 
pop() + assert isinstance(k, ast.Const) + k = k.value # ast.Name(k.value) + args.append(ast.Keyword(k, v)) for i in range(argc): args.append(pop()) args.reverse() From fd3743183928fc00a67505bf7af3455e7fc52317 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 14 Apr 2019 19:25:12 +0300 Subject: [PATCH 455/547] Improve decompiler tests output --- pony/orm/tests/test_decompiler.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/pony/orm/tests/test_decompiler.py b/pony/orm/tests/test_decompiler.py index ac31dc49f..e0c7e6d74 100644 --- a/pony/orm/tests/test_decompiler.py +++ b/pony/orm/tests/test_decompiler.py @@ -68,22 +68,27 @@ def get_condition_values(cond): code1 = compile(src1, '', 'eval').co_consts[0] ast1 = Decompiler(code1).ast - src1 = ast2src(ast1).replace('.0', '[]') - src1 = src1[src1.find('if')+2:-1] + res1 = ast2src(ast1).replace('.0', '[]') + res1 = res1[res1.find('if')+2:-1] code2 = compile(src2, '', 'eval').co_consts[0] ast2 = Decompiler(code2).ast - src2 = ast2src(ast2).replace('.0', '[]') - src2 = src2[src2.find(':')+1:] + res2 = ast2src(ast2).replace('.0', '[]') + res2 = res2[res2.find(':')+1:] code3 = compile(src3, '', 'eval').co_consts[0] ast3 = Decompiler(code3).ast - src3 = ast2src(ast3).replace('.0', '[]') - src3 = src3[src3.find('if')+2: src3.rfind('for')-1] + res3 = ast2src(ast3).replace('.0', '[]') + res3 = res3[res3.find('if')+2: res3.rfind('for')-1] - self.assertEqual(get_condition_values(gen), get_condition_values(src1)) - self.assertEqual(get_condition_values(gen), get_condition_values(src2)) - self.assertEqual(get_condition_values(gen), get_condition_values(src3)) + if get_condition_values(gen) != get_condition_values(res1): + self.fail("Incorrect generator decompilation: %s -> %s" % (gen, res1)) + + if get_condition_values(gen) != get_condition_values(res2): + self.fail("Incorrect lambda decompilation: %s -> %s" % (gen, res2)) + + if get_condition_values(gen) != 
get_condition_values(res3): + self.fail("Incorrect multi-for generator decompilation: %s -> %s" % (gen, res3)) return wrapped_test From 3a0467dc16e0f63a11265c8659d043d7e59092a0 Mon Sep 17 00:00:00 2001 From: Carl George Date: Sun, 3 Mar 2019 19:29:44 -0600 Subject: [PATCH 456/547] Include LICENSE in sdist The Apache license (section 4) requires the license and copyright be redistributed with copies of the work. --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) diff --git a/MANIFEST.in b/MANIFEST.in index 8f03fd3f8..e05b34645 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include pony/orm/tests/queries.txt include pony/flask/example/templates *.html +include LICENSE From 0997b3581cae8d2f5782cb673c35cecf475822fe Mon Sep 17 00:00:00 2001 From: Carl George Date: Sun, 3 Mar 2019 19:29:44 -0600 Subject: [PATCH 457/547] Include LICENSE in sdist The Apache license (section 4) requires the license and copyright be redistributed with copies of the work. --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) diff --git a/MANIFEST.in b/MANIFEST.in index 8f03fd3f8..e05b34645 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,3 @@ include pony/orm/tests/queries.txt include pony/flask/example/templates *.html +include LICENSE From e804b9308cf34005361812b0abcc71819fd41101 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 20 Apr 2019 17:40:45 +0300 Subject: [PATCH 458/547] Fixes error messages for PyPy --- pony/orm/tests/test_declarative_func_monad.py | 1 + pony/orm/tests/test_select_from_select_queries.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index 656c3de48..2a7b3f0a3 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -115,6 +115,7 @@ def test_datetime_now1(self): self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) 
@raises_exception(ExprEvalError, "`1 < datetime.now()` raises TypeError: " + ( "can't compare 'datetime' to 'int'" if PYPY2 else + "'<' not supported between instances of 'int' and 'datetime'" if PYPY and sys.version_info >= (3, 6) else "unorderable types: int < datetime" if PYPY else "can't compare datetime.datetime to int" if PY2 else "unorderable types: int() < datetime.datetime()" if sys.version_info < (3, 6) else diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index 9210687e4..4322b76d5 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -2,6 +2,7 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.py23compat import PYPY2 db = Database('sqlite', ':memory:') @@ -88,7 +89,8 @@ def test_6(self): # selecting hybrid property in the first query self.assertEqual(db.last_sql.count('SELECT'), 1) @db_session - @raises_exception(ExprEvalError, "`s.scholarship > 0` raises NameError: name 's' is not defined") + @raises_exception(ExprEvalError, "`s.scholarship > 0` raises NameError: name 's' is not defined" if not PYPY2 + else "`s.scholarship > 0` raises NameError: global name 's' is not defined") def test_7(self): # test access to original query var name from the new query q = select(s.first_name for s in Student if s.scholarship < 500) q2 = select(x for x in q if s.scholarship > 0) From f781ae4fd8cc0b357cb89e7c29a5b06c5f185a15 Mon Sep 17 00:00:00 2001 From: Fabio Alessandrelli Date: Mon, 25 Feb 2019 13:36:51 +0100 Subject: [PATCH 459/547] Fix GROUP_CONCAT separator syntax. 
--- pony/orm/sqlbuilding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 9c4d1d4b8..ce6611147 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -480,7 +480,10 @@ def GROUP_CONCAT(builder, distinct, expr, sep=None): assert distinct in (None, True, False) result = distinct and 'GROUP_CONCAT(DISTINCT ' or 'GROUP_CONCAT(', builder(expr) if sep is not None: - result = result, ', ', builder(sep) + if builder.provider.dialect == 'MySQL': + result = result, ' SEPARATOR ', builder(sep) + else: + result = result, ', ', builder(sep) return result, ')' UPPER = make_unary_func('upper') LOWER = make_unary_func('lower') From 3566508e20a9b91e74ad74441c9982a7df8c0698 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 20 Apr 2019 17:13:34 +0300 Subject: [PATCH 460/547] Oracle 19C supports DISTINCT in LISTAGG --- pony/orm/dbproviders/oracle.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index 57d86d990..d4f338be0 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -276,7 +276,11 @@ def JSON_ARRAY_LENGTH(builder, value): throw(TranslationError, 'Oracle does not provide `length` function for JSON arrays') def GROUP_CONCAT(builder, distinct, expr, sep=None): assert distinct in (None, True, False) - result = 'LISTAGG(', builder(expr) + if distinct and builder.provider.server_version >= (19,): + distinct = 'DISTINCT ' + else: + distinct = '' + result = 'LISTAGG(', distinct, builder(expr) if sep is not None: result = result, ', ', builder(sep) else: From 1b26f90ce24ce5a10b155995ac47711948c048b3 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 20 Apr 2019 17:28:16 +0300 Subject: [PATCH 461/547] Fix test: add second column to ORDER BY section for stable sort order --- pony/orm/tests/test_select_from_select_queries.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index 4322b76d5..a8f7f860d 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -379,7 +379,7 @@ def test_44(self): @db_session def test_45(self): - q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) + q = select(s for s in Student).order_by(Student.first_name, Student.id).limit(3, offset=1) q2 = select(s for s in q if s.age > 18).limit(2, offset=1) q3 = select(s.last_name for s in q2).limit(2, offset=1) self.assertEqual(set(q3), {'Brown'}) From a068e9c8ab0f0cb69770d5e19a70fdcc04fad296 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Sat, 20 Apr 2019 18:27:12 +0300 Subject: [PATCH 462/547] Adding a list of awesome Pony ORM supporters --- BACKERS.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) create mode 100644 BACKERS.md diff --git a/BACKERS.md b/BACKERS.md new file mode 100644 index 000000000..6b4a6ba51 --- /dev/null +++ b/BACKERS.md @@ -0,0 +1,13 @@ +# Sponsors & Backers + +Pony ORM is Apache 2.0 licensed open source project. 
If you would like to support Pony ORM development, please consider: + +[Become a backer or sponsor](https://ponyorm.org/donation.html) + +## Backers + +- Sergio Aguilar Guerrero +- David ROUBLOT +- Elijas Dapšauskas +- Dan Swain + From e9f97b205d53093391875abe9239dcde79f277f7 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Sat, 20 Apr 2019 18:30:36 +0300 Subject: [PATCH 463/547] Fix: Readable error message while using infinity or NaN Decimal values --- pony/orm/dbproviders/sqlite.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 04411bc3c..d9009e976 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -173,6 +173,9 @@ def sql_type(converter): return dbapiprovider.IntConverter.sql_type(converter) class SQLiteDecimalConverter(dbapiprovider.DecimalConverter): + inf = Decimal('infinity') + neg_inf = Decimal('-infinity') + NaN = Decimal('NaN') def sql2py(converter, val): try: val = Decimal(str(val)) except: return val @@ -182,7 +185,10 @@ def sql2py(converter, val): def py2sql(converter, val): if type(val) is not Decimal: val = Decimal(val) exp = converter.exp - if exp is not None: val = val.quantize(exp) + if exp is not None: + if val in (converter.inf, converter.neg_inf, converter.NaN): + throw(ValueError, 'Cannot store %s Decimal value in database' % val) + val = val.quantize(exp) return str(val) class SQLiteDateConverter(dbapiprovider.DateConverter): From 884d517538a68c196b50387a6227323b5f5d5c94 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 20 Apr 2019 18:42:05 +0300 Subject: [PATCH 464/547] Update changelog and Pony version: 0.7.10-dev -> 0.7.10 --- CHANGELOG.md | 12 ++++++++++++ pony/__init__.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 713dac2c0..65d16d9c8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,15 @@ +# PonyORM release 0.7.10 
(2019-04-20) + +## Bugfixes + +* Python3.7 and PyPy decompiling fixes +* Fix reading NULL from Optional nullable array column +* Fix handling of empty arrays in queries +* #415: error message typo +* #432: PonyFlask - request object can trigger teardown_request without real request +* Fix GROUP CONCAT separator for MySQL + + # PonyORM release 0.7.9 (2019-01-21) ## Bugfixes diff --git a/pony/__init__.py b/pony/__init__.py index 808a03d62..202a679ed 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.10-dev' +__version__ = '0.7.10' uid = str(random.randint(1, 1000000)) From bb92bb2128609b9ac6b7ea89f302e640263e0053 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 27 Apr 2019 13:28:10 +0300 Subject: [PATCH 465/547] Update Pony version: 0.7.10 -> 0.7.11-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 202a679ed..f0f04fe91 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.10' +__version__ = '0.7.11-dev' uid = str(random.randint(1, 1000000)) From 8a05edf7ee0eb0d9a7556a6a9b029c8b811ad244 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 4 Sep 2019 15:41:40 +0300 Subject: [PATCH 466/547] Fix #463: changing ponyorm.com to ponyorm.org --- README.md | 8 ++++---- pony/orm/core.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 05416ff6b..374c82742 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ Pony ORM also has the Entity-Relationship Diagram Editor which is a great tool f Documentation ------------- -Documenation is available at [https://docs.ponyorm.com](https://docs.ponyorm.com) +Documenation is available at [https://docs.ponyorm.org](https://docs.ponyorm.org) The documentation source is avaliable at 
[https://github.com/ponyorm/pony-doc](https://github.com/ponyorm/pony-doc). Please create new documentation related issues [here](https://github.com/ponyorm/pony-doc/issues) or make a pull request with your improvements. @@ -46,8 +46,8 @@ PonyORM community ----------------- Please post your questions on [Stack Overflow](http://stackoverflow.com/questions/tagged/ponyorm). -Meet the PonyORM team, chat with the community members, and get your questions answered on our community [Telegram group](https://telegram.me/ponyorm). -Join our newsletter at [ponyorm.com](https://ponyorm.com). +Meet the PonyORM team, chat with the community members, and get your questions answered on our community [Telegram group](https://t.me/ponyorm). +Join our newsletter at [ponyorm.org](https://ponyorm.org). Reach us on [Twitter](https://twitter.com/ponyorm). -Copyright (c) 2018 Pony ORM, LLC. All rights reserved. team (at) ponyorm.com +Copyright (c) 2013-2019 Pony ORM. All rights reserved. info (at) ponyorm.org diff --git a/pony/orm/core.py b/pony/orm/core.py index 958392b21..a7820ce14 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4869,7 +4869,7 @@ def _db_set_(obj, avdict, unpickling=False): bit = obj._bits_except_volatile_[attr] if rbits & bit: errormsg = 'Please contact PonyORM developers so they can ' \ - 'reproduce your error and fix a bug: support@ponyorm.com' + 'reproduce your error and fix a bug: support@ponyorm.org' assert old_dbval is not NOT_LOADED, errormsg throw(UnrepeatableReadError, 'Value of %s.%s for %s was updated outside of current transaction (was: %r, now: %r)' From c24789e980f35f19a2308841f063a30a7a8c6009 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Sat, 20 Apr 2019 20:25:38 +0300 Subject: [PATCH 467/547] Minor change in README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 374c82742..1ff502a1f 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ Pony Object-Relational Mapper 
============================= -Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions. Pony analyzes the abstract syntax tree of the generator expression and translates it into a SQL query. +Pony is an advanced object-relational mapper. The most interesting feature of Pony is its ability to write queries to the database using Python generator expressions and lambdas. Pony analyzes the abstract syntax tree of the expression and translates it into a SQL query. Here is an example query in Pony: From cf8b430b6ace1f1dedc54d048bf015c55c670210 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Sat, 20 Apr 2019 20:25:38 +0300 Subject: [PATCH 468/547] Add support development information --- README.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.md b/README.md index 1ff502a1f..6ebf3a190 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,14 @@ All this helps the developer to focus on implementing the business logic of an a See the example [here](https://github.com/ponyorm/pony/blob/orm/pony/orm/examples/estore.py) +Support Pony ORM Development +---------------------------- + +Pony ORM is Apache 2.0 licensed open source project. 
If you would like to support Pony ORM development, please consider: + +[Become a backer or sponsor](https://ponyorm.org/donation.html) + + Online tool for database design ------------------------------- From ccfd6232b35b47533497667fbb51072db36af31f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sun, 2 Jun 2019 18:06:22 +0300 Subject: [PATCH 469/547] Fixed #430: add ON DELETE CASCADE for many-to-many relationships --- pony/orm/core.py | 19 +++++++++----- pony/orm/tests/test_relations_m2m.py | 38 ++++++++++++++++++---------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index a7820ce14..6766ecf28 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1107,10 +1107,13 @@ def get_columns(table, column_names): m2m_table = schema.tables[attr.table] parent_columns = get_columns(table, entity._pk_columns_) child_columns = get_columns(m2m_table, reverse.columns) - m2m_table.add_foreign_key(reverse.fk_name, child_columns, table, parent_columns, attr.index) + on_delete = 'CASCADE' + m2m_table.add_foreign_key(reverse.fk_name, child_columns, table, parent_columns, + attr.index, on_delete) if attr.symmetric: - child_columns = get_columns(m2m_table, attr.reverse_columns) - m2m_table.add_foreign_key(attr.reverse_fk_name, child_columns, table, parent_columns) + reverse_child_columns = get_columns(m2m_table, attr.reverse_columns) + m2m_table.add_foreign_key(attr.reverse_fk_name, reverse_child_columns, table, parent_columns, + attr.reverse_index, on_delete) elif attr.reverse and attr.columns: rentity = attr.reverse.entity parent_table = schema.tables[rentity._table_] @@ -2008,7 +2011,7 @@ class Attribute(object): 'id', 'pk_offset', 'pk_columns_offset', 'py_type', 'sql_type', 'entity', 'name', \ 'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \ 'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \ - 'cascade_delete', 'index', 'original_default', 'sql_default', 
'py_check', 'hidden', \ + 'cascade_delete', 'index', 'reverse_index', 'original_default', 'sql_default', 'py_check', 'hidden', \ 'optimistic', 'fk_name', 'type_has_empty_value' def __deepcopy__(attr, memo): return attr # Attribute cannot be cloned by deepcopy() @@ -2069,6 +2072,7 @@ def __init__(attr, py_type, *args, **kwargs): if len(attr.columns) == 1: attr.column = attr.columns[0] else: attr.columns = [] attr.index = kwargs.pop('index', None) + attr.reverse_index = kwargs.pop('reverse_index', None) attr.fk_name = kwargs.pop('fk_name', None) attr.col_paths = [] attr._columns_checked = False @@ -2714,8 +2718,11 @@ def _init_(attr, entity, name): if attr.default is not None: throw(TypeError, 'Default value could not be set for collection attribute') attr.symmetric = (attr.py_type == entity.__name__ and attr.reverse == name) - if not attr.symmetric and attr.reverse_columns: throw(TypeError, - "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only") + if not attr.symmetric: + if attr.reverse_columns: + throw(TypeError, "'reverse_column' and 'reverse_columns' options can be set for symmetric relations only") + if attr.reverse_index: + throw(TypeError, "'reverse_index' option can be set for symmetric relations only") if attr.py_check is not None: throw(NotImplementedError, "'py_check' parameter is not supported for collection attributes") def load(attr, obj): diff --git a/pony/orm/tests/test_relations_m2m.py b/pony/orm/tests/test_relations_m2m.py index 4c80f885b..2f1708986 100644 --- a/pony/orm/tests/test_relations_m2m.py +++ b/pony/orm/tests/test_relations_m2m.py @@ -32,6 +32,16 @@ class Subject(db.Entity): g1.subjects = [ s1, s2 ] def test_1(self): + schema = self.db.schema + m2m_table_name = 'Group_Subject' + self.assertIn(m2m_table_name, schema.tables) + m2m_table = schema.tables[m2m_table_name] + fkeys = list(m2m_table.foreign_keys.values()) + self.assertEqual(len(fkeys), 2) + for fk in fkeys: + self.assertEqual(fk.on_delete, 
'CASCADE') + + def test_2(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -43,7 +53,7 @@ def test_1(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - def test_2(self): + def test_3(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -55,7 +65,7 @@ def test_2(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2', 'Subj3']) - def test_3(self): + def test_4(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -67,7 +77,7 @@ def test_3(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - def test_4(self): + def test_5(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -93,7 +103,7 @@ def test_5(self): self.assertEqual(db_subjects , ['Subj3', 'Subj4']) self.assertEqual(Group[101].subjects, {Subject['Subj3'], Subject['Subj4']}) - def test_6(self): + def test_7(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -108,7 +118,7 @@ def test_6(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - def test_7(self): + def test_8(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -123,7 +133,7 @@ def test_7(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - def test_8(self): + def test_9(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -139,7 +149,7 @@ def test_8(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - def test_9(self): + def test_10(self): db, Group, Subject = self.db, 
self.Group, self.Subject with db_session: @@ -154,7 +164,7 @@ def test_9(self): db_subjects = db.select('subject from Group_Subject where "group" = 102') self.assertEqual(db_subjects , []) - def test_10(self): + def test_11(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -166,7 +176,7 @@ def test_10(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj2', 'Subj3']) - def test_11(self): + def test_12(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -181,7 +191,7 @@ def test_11(self): db_subjects = db.select('subject from Group_Subject where "group" = 101') self.assertEqual(db_subjects , ['Subj1', 'Subj2']) - def test_12(self): + def test_13(self): db, Group, Subject = self.db, self.Group, self.Subject with db_session: @@ -197,7 +207,7 @@ def test_12(self): self.assertEqual(db_subjects , ['Subj1', 'Subj2']) @db_session - def test_13(self): + def test_14(self): db, Group, Subject = self.db, self.Group, self.Subject g1 = Group[101] @@ -231,7 +241,7 @@ def test_13(self): self.assertEqual(subj_setdata.removed, set()) @db_session - def test_14(self): + def test_15(self): db, Group, Subject = self.db, self.Group, self.Subject g = Group[101] @@ -253,7 +263,7 @@ def test_14(self): self.assertEqual(db.last_sql, None) @db_session - def test_15(self): + def test_16(self): db, Group = self.db, self.Group g = Group[101] @@ -273,7 +283,7 @@ def test_15(self): self.assertEqual(db.last_sql, None) @db_session - def test_16(self): + def test_17(self): db, Group, Subject = self.db, self.Group, self.Subject g = Group[101] From 61605272e464c5ebb898aa62efd129a71323dd20 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Sat, 22 Jun 2019 18:06:32 +0300 Subject: [PATCH 470/547] Fix aggregate bug by reverting 876a844a --- pony/orm/sqltranslation.py | 16 ++++++---------- .../tests/test_declarative_query_set_monad.py | 10 ++++++++++ 2 files changed, 16 
insertions(+), 10 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 6b01fd1c8..a402831af 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -659,11 +659,6 @@ def construct_sql_ast(translator, limit=None, offset=None, distinct=None, translator.query_result_is_cacheable = False else: sql_ast = [ 'SELECT' ] - groupby_monads = translator.groupby_monads - if distinct and translator.aggregated and not groupby_monads: - distinct = False - groupby_monads = translator.expr_monads - select_ast = [ 'DISTINCT' if distinct else 'ALL' ] + translator.expr_columns if aggr_func_name: expr_type = translator.expr_type @@ -681,9 +676,10 @@ def construct_sql_ast(translator, limit=None, offset=None, distinct=None, throw(TypeError, '%r is valid for numeric attributes only' % aggr_func_name.lower()) assert len(translator.expr_columns) == 1 aggr_ast = None - if groupby_monads or (aggr_func_name == 'COUNT' and distinct - and isinstance(translator.expr_type, EntityMeta) - and len(translator.expr_columns) > 1): + if translator.groupby_monads or ( + aggr_func_name == 'COUNT' and distinct + and isinstance(translator.expr_type, EntityMeta) + and len(translator.expr_columns) > 1): outer_alias = 't' if aggr_func_name == 'COUNT' and not aggr_func_distinct: outer_aggr_ast = [ 'COUNT', None ] @@ -736,9 +732,9 @@ def ast_transformer(ast): if conditions: sql_ast.append([ 'WHERE' ] + conditions) - if groupby_monads: + if translator.groupby_monads: group_by = [ 'GROUP_BY' ] - for m in groupby_monads: group_by.extend(m.getsql()) + for m in translator.groupby_monads: group_by.extend(m.getsql()) sql_ast.append(group_by) else: group_by = None diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index c0c989662..7c545a2d4 100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -352,6 +352,16 @@ def 
test_select_from_select_10(self): result = select(n for n, a in query if n.endswith('2') and a > 20) self.assertEqual(set(x for x in result), {'S2'}) + def test_aggregations_1(self): + query = select((min(s.age), max(s.scholarship)) for s in Student) + result = query[:] + self.assertEqual(result, [(20, 500)]) + + def test_aggregations_2(self): + query = select((min(s.age), max(s.scholarship)) for s in Student for g in Group) + result = query[:] + self.assertEqual(result, [(20, 500)]) + if __name__ == "__main__": unittest.main() From ac69f12cc875a5085d8dfb1959166ef2b7b588a7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 27 Aug 2019 18:24:31 +0300 Subject: [PATCH 471/547] Refactoring of Entity._parse_row_() --- pony/orm/core.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 6766ecf28..57cdfe069 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4288,10 +4288,14 @@ def _parse_row_(entity, row, attr_offsets): avdict = {} for attr in real_entity_subclass._attrs_: offsets = attr_offsets.get(attr) - if offsets is None or attr.is_discriminator: continue - avdict[attr] = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) + if offsets is None: + continue + if attr.is_discriminator: + avdict[attr] = discr_value + else: + avdict[attr] = attr.parse_value(row, offsets, cache.dbvals_deduplication_cache) - pkval = tuple(avdict.pop(attr, discr_value) for attr in entity._pk_attrs_) + pkval = tuple(avdict.pop(attr) for attr in entity._pk_attrs_) assert None not in pkval if not entity._pk_is_composite_: pkval = pkval[0] return real_entity_subclass, pkval, avdict From 1c8498cf8da4e55bb9523daa506e737995329371 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 27 Aug 2019 20:01:02 +0300 Subject: [PATCH 472/547] Check value of discriminator column on object creation --- pony/orm/core.py | 27 +++++++++++++++++++-------- pony/orm/tests/test_inheritance.py | 13 +++++++++++++ 
2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 57cdfe069..9f0e0e4db 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2158,6 +2158,12 @@ def __repr__(attr): return '%s.%s' % (owner_name, attr.name or '?') def __lt__(attr, other): return attr.id < other.id + def _get_entity(attr, obj, entity): + if entity is not None: + return entity + if obj is not None: + return obj.__class__ + return attr.entity def validate(attr, val, obj=None, entity=None, from_db=False): val = deref_proxy(val) if val is None: @@ -2172,10 +2178,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): if callable(default): val = default() else: val = default - if entity is not None: pass - elif obj is not None: entity = obj.__class__ - else: entity = attr.entity - + entity = attr._get_entity(obj, entity) reverse = attr.reverse if not reverse: if isinstance(val, Entity): throw(TypeError, 'Attribute %s must be of %s type. Got: %s' @@ -2555,7 +2558,7 @@ def process_entity_inheritance(attr, entity): entity._discriminator_ = entity.__name__ discr_value = entity._discriminator_ if discr_value is not None: - try: entity._discriminator_ = discr_value = attr.validate(discr_value) + try: entity._discriminator_ = discr_value = attr.validate(discr_value, None, entity) except ValueError: throw(TypeError, "Incorrect discriminator value is set for %s attribute '%s' of '%s' type: %r" % (entity.__name__, attr.name, attr.py_type.__name__, discr_value)) @@ -2566,10 +2569,18 @@ def process_entity_inheritance(attr, entity): % (entity.__name__, attr.name, attr.py_type.__name__)) attr.code2cls[discr_value] = entity def validate(attr, val, obj=None, entity=None, from_db=False): - if from_db: return val - elif val is DEFAULT: + if from_db: + return val + entity = attr._get_entity(obj, entity) + if val is DEFAULT: assert entity is not None return entity._discriminator_ + if val != entity._discriminator_: + for cls in entity._subclasses_: + if val 
== cls._discriminator_: + break + else: throw(TypeError, 'Invalid discriminator attribute value for %s. Expected: %r, got: %r' + % (entity.__name__, entity._discriminator_, val)) return Attribute.validate(attr, val, obj, entity) def load(attr, obj): assert False # pragma: no cover @@ -4655,7 +4666,7 @@ def __init__(obj, *args, **kwargs): if name not in entity._adict_: throw(TypeError, 'Unknown attribute %r' % name) for attr in entity._attrs_: val = kwargs.get(attr.name, DEFAULT) - avdict[attr] = attr.validate(val, obj, entity, from_db=False) + avdict[attr] = attr.validate(val, obj, from_db=False) if entity._pk_is_composite_: pkval = tuple(imap(avdict.get, entity._pk_attrs_)) if None in pkval: pkval = None diff --git a/pony/orm/tests/test_inheritance.py b/pony/orm/tests/test_inheritance.py index 6538d4abf..746be42d9 100644 --- a/pony/orm/tests/test_inheritance.py +++ b/pony/orm/tests/test_inheritance.py @@ -276,10 +276,23 @@ class Entity2(db.Entity1): self.assertEqual(obj._pkval_, ('Entity2', 20)) with db_session: obj = Entity1['Entity2', 20] + self.assertIsInstance(obj, Entity2) self.assertEqual(obj.a, 'Entity2') self.assertEqual(obj.b, 20) self.assertEqual(obj._pkval_, ('Entity2', 20)) + @raises_exception(TypeError, "Invalid discriminator attribute value for Foo. 
Expected: 'Foo', got: 'Baz'") + def test_discriminator_2(self): + db = Database('sqlite', ':memory:') + class Foo(db.Entity): + id = PrimaryKey(int) + a = Discriminator(str) + b = Required(int) + class Bar(db.Entity): + c = Required(int) + db.generate_mapping(create_tables=True) + with db_session: + x = Foo(id=1, a='Baz', b=100) if __name__ == '__main__': unittest.main() From 803144b421d96c26b3d740dd606c6e6304bb9407 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Aug 2019 14:25:33 +0300 Subject: [PATCH 473/547] Remove incorrect assertion --- pony/orm/core.py | 1 - pony/orm/tests/test_frames.py | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 9f0e0e4db..fd5a348e9 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4358,7 +4358,6 @@ def _query_from_args_(entity, args, kwargs, frame_depth): for_expr = ast.GenExprFor(ast.AssName(name, 'OP_ASSIGN'), ast.Name('.0'), [ if_expr ]) inner_expr = ast.GenExprInner(ast.Name(name), [ for_expr ]) locals = locals.copy() if locals is not None else {} - assert '.0' not in locals locals['.0'] = entity return Query(code_key, inner_expr, globals, locals, cells) def _get_from_identity_map_(entity, pkval, status, for_update=False, undo_funcs=None, obj_to_init=None): diff --git a/pony/orm/tests/test_frames.py b/pony/orm/tests/test_frames.py index 68ac3a601..6d4bc1a39 100644 --- a/pony/orm/tests/test_frames.py +++ b/pony/orm/tests/test_frames.py @@ -3,6 +3,8 @@ import unittest from pony.orm.core import * +import pony.orm.decompiling +from pony.orm.tests.testutils import * db = Database('sqlite', ':memory:') @@ -167,5 +169,19 @@ def test_db_exists(self): result = db.exists('name from Person where age = $x') self.assertEqual(result, True) + @raises_exception(pony.orm.decompiling.InvalidQuery, + 'Use generator expression (... for ... in ...) ' + 'instead of list comprehension [... for ... in ...] 
inside query') + @db_session + def test_inner_list_comprehension(self): + result = select(p.id for p in Person if p.age not in [ + p2.age for p2 in Person if p2.name.startswith('M')])[:] + + @db_session + def test_outer_list_comprehension(self): + names = ['John', 'Mary', 'Mike'] + persons = [ Person.select(lambda p: p.name == name).first() for name in names ] + self.assertEqual(set(p.name for p in persons), {'John', 'Mary', 'Mike'}) + if __name__ == '__main__': unittest.main() From 0f136ec1e54d1aadfa560c684e9523a73e252271 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Aug 2019 16:17:06 +0300 Subject: [PATCH 474/547] Fix test for GROUP_CONCAT syntax for MySQL --- pony/orm/tests/queries.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 2ce040941..0c785f433 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -949,7 +949,7 @@ GROUP BY "g"."number" MySQL: -SELECT `g`.`number`, GROUP_CONCAT(`s`.`name`, '+') +SELECT `g`.`number`, GROUP_CONCAT(`s`.`name` SEPARATOR '+') FROM `group` `g`, `student` `s` WHERE `g`.`number` = `s`.`group` GROUP BY `g`.`number` From bc8521773f9bee4de0314d40dbcac5dc1f9c44a6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 28 Aug 2019 18:31:24 +0300 Subject: [PATCH 475/547] Remove unnecessary DISTINCT in aggregated queries --- pony/orm/sqltranslation.py | 2 + pony/orm/tests/queries.txt | 8 +-- pony/orm/tests/test_distinct.py | 86 +++++++++++++++++++++++++++++++++ 3 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 pony/orm/tests/test_distinct.py diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a402831af..65202d8e0 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -504,6 +504,8 @@ def func(value, converter=converter): offset += 1 translator.row_layout = row_layout translator.col_names = [ src for func, slice_or_offset, src in 
translator.row_layout ] + if translator.aggregated: + translator.distinct = False translator.vars = None if translator is not this: raise UseAnotherTranslator(translator) diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 0c785f433..ac371de88 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -255,7 +255,7 @@ HAVING "order"."total_price" < coalesce(SUM(("item"."price" * "item"."quantity") >>> select(c for c in Customer for p in c.orders.items.product if 'Tablets' in p.categories.name and count(p) > 1) -SELECT DISTINCT "c"."id" +SELECT "c"."id" FROM "Customer" "c", "Order" "order", "OrderItem" "orderitem" WHERE 'Tablets' IN ( SELECT "category"."name" @@ -516,7 +516,7 @@ GROUP BY "g"."number" >>> select((s, count(c)) for s in Student for c in s.courses) -SELECT DISTINCT "s"."id", COUNT(DISTINCT "c"."ROWID") +SELECT "s"."id", COUNT(DISTINCT "c"."ROWID") FROM "Student" "s", "Course_Student" "t-1", "Course" "c" WHERE "s"."id" = "t-1"."student" AND "t-1"."course_name" = "c"."name" @@ -525,7 +525,7 @@ GROUP BY "s"."id" Oracle: -SELECT DISTINCT "s"."ID", COUNT(DISTINCT "c"."ROWID") +SELECT "s"."ID", COUNT(DISTINCT "c"."ROWID") FROM "STUDENT" "s", "COURSE_STUDENT" "t-1", "COURSE" "c" WHERE "s"."ID" = "t-1"."STUDENT" AND "t-1"."COURSE_NAME" = "c"."NAME" @@ -534,7 +534,7 @@ GROUP BY "s"."ID" PostgreSQL: -SELECT DISTINCT "s"."id", COUNT(DISTINCT case when ("t-1"."course_name", "t-1"."course_semester") IS NULL then null else ("t-1"."course_name", "t-1"."course_semester") end) +SELECT "s"."id", COUNT(DISTINCT case when ("t-1"."course_name", "t-1"."course_semester") IS NULL then null else ("t-1"."course_name", "t-1"."course_semester") end) FROM "student" "s", "course_student" "t-1" WHERE "s"."id" = "t-1"."student" GROUP BY "s"."id" diff --git a/pony/orm/tests/test_distinct.py b/pony/orm/tests/test_distinct.py new file mode 100644 index 000000000..5f4ce1eb5 --- /dev/null +++ b/pony/orm/tests/test_distinct.py @@ -0,0 +1,86 @@ +from 
__future__ import absolute_import, print_function, division + +import unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:') + +class Department(db.Entity): + number = PrimaryKey(int) + groups = Set('Group') + +class Group(db.Entity): + id = PrimaryKey(int) + dept = Required('Department') + students = Set('Student') + +class Student(db.Entity): + name = Required(unicode) + age = Required(int) + group = Required('Group') + scholarship = Required(int, default=0) + courses = Set('Course') + +class Course(db.Entity): + name = Required(unicode) + semester = Required(int) + PrimaryKey(name, semester) + students = Set('Student') + +db.generate_mapping(create_tables=True) + +with db_session: + d1 = Department(number=1) + d2 = Department(number=2) + g1 = Group(id=1, dept=d1) + g2 = Group(id=2, dept=d2) + s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) + s2 = Student(id=2, name='S2', age=21, group=g1, scholarship=100) + s3 = Student(id=3, name='S3', age=23, group=g1, scholarship=200) + s4 = Student(id=4, name='S4', age=21, group=g1, scholarship=100) + s5 = Student(id=5, name='S5', age=23, group=g2, scholarship=0) + s6 = Student(id=6, name='S6', age=23, group=g2, scholarship=200) + c1 = Course(name='C1', semester=1, students=[s1, s2, s3]) + c2 = Course(name='C2', semester=1, students=[s2, s3, s5, s6]) + c3 = Course(name='C3', semester=2, students=[s4, s5, s6]) + + +class TestDistinct(unittest.TestCase): + def setUp(self): + db_session.__enter__() + + def tearDown(self): + db_session.__exit__() + + def test_group_by(self): + result = set(select((s.age, sum(s.scholarship)) for s in Student if s.scholarship > 0)) + self.assertEqual(result, {(21, 200), (23, 400)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_group_by_having(self): + result = set(select((s.age, sum(s.scholarship)) for s in Student if sum(s.scholarship) < 300)) + self.assertEqual(result, {(20, 0), (21, 200)}) + 
self.assertNotIn('distinct', db.last_sql.lower()) + + def test_aggregation_no_group_by_1(self): + result = set(select(sum(s.scholarship) for s in Student if s.age < 23)) + self.assertEqual(result, {200}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_aggregation_no_group_by_2(self): + result = set(select((sum(s.scholarship), min(s.scholarship)) for s in Student if s.age < 23)) + self.assertEqual(result, {(200, 0)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + def test_aggregation_no_group_by_3(self): + result = set(select((sum(s.scholarship), min(s.scholarship)) + for s in Student for g in Group + if s.group == g and g.dept.number == 1)) + self.assertEqual(result, {(400, 0)}) + self.assertNotIn('distinct', db.last_sql.lower()) + + +if __name__ == "__main__": + unittest.main() From 6c37d6c55ce8a0365a2065aecdfb82c1c5df2b58 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 12 Sep 2019 13:04:45 +0300 Subject: [PATCH 476/547] Deref proxies when add items to collections --- pony/orm/core.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index fd5a348e9..c5994ca58 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2787,6 +2787,7 @@ def param(i, j, converter): class Set(Collection): __slots__ = [] def validate(attr, val, obj=None, entity=None, from_db=False): + val = deref_proxy(val) assert val is not NOT_LOADED if val is DEFAULT: return set() reverse = attr.reverse @@ -2803,6 +2804,7 @@ def validate(attr, val, obj=None, entity=None, from_db=False): except TypeError: throw(TypeError, 'Item of collection %s.%s must be an instance of %s. Got: %r' % (entity.__name__, attr.name, rentity.__name__, val)) for item in items: + item = deref_proxy(item) if not isinstance(item, rentity): throw(TypeError, 'Item of collection %s.%s must be an instance of %s. 
Got: %r' % (entity.__name__, attr.name, rentity.__name__, item)) From 454f6a566cb06f2db46c98c5ef45e538f9c8a3fd Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Thu, 19 Sep 2019 19:42:08 +0300 Subject: [PATCH 477/547] Update BACKERS.md --- BACKERS.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/BACKERS.md b/BACKERS.md index 6b4a6ba51..aa7ee850b 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -10,4 +10,6 @@ Pony ORM is Apache 2.0 licensed open source project. If you would like to suppor - David ROUBLOT - Elijas Dapšauskas - Dan Swain - +- Christian Macht +- Johnathan Nader +- Andrei Rachalouski From 4d45ddf0dd8315e1f07c2a1e5dd8bf24a6f3837e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 24 Sep 2019 13:37:33 +0300 Subject: [PATCH 478/547] Remove unnecessary @db_session --- pony/orm/tests/test_query.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 328d16fca..eb6fa2615 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -119,26 +119,21 @@ def test22(self): def test23(self): r = max(s.dob.year for s in Student) self.assertEqual(r, 2001) - @db_session def test_first1(self): q = select(s for s in Student).order_by(Student.gpa) self.assertEqual(q.first(), Student[1]) - @db_session def test_first2(self): q = select((s.name, s.group) for s in Student) self.assertEqual(q.first(), ('S1', Group[1])) - @db_session def test_first3(self): q = select(s for s in Student) self.assertEqual(q.first(), Student[1]) - @db_session def test_closures_1(self): def find_by_gpa(gpa): return lambda s: s.gpa > gpa fn = find_by_gpa(Decimal('3.1')) students = list(Student.select(fn)) self.assertEqual(students, [ Student[2], Student[3] ]) - @db_session def test_closures_2(self): def find_by_gpa(gpa): return lambda s: s.gpa > gpa From 26c1d8c187cbca924235f5175b9257dc7e1edc0c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 24 Sep 2019 13:38:42 
+0300 Subject: [PATCH 479/547] Bulk delete should clear query results cache --- pony/orm/core.py | 1 + pony/orm/tests/test_query.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/pony/orm/core.py b/pony/orm/core.py index c5994ca58..484f73ac3 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5948,6 +5948,7 @@ def delete(query, bulk=None): cache.immediate = True cache.prepare_connection_for_query_execution() # may clear cache.query_results cursor = database._exec_sql(sql, arguments) + cache.query_results.clear() return cursor.rowcount @cut_traceback def __len__(query): diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index eb6fa2615..f208541ba 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -155,6 +155,15 @@ def test_pickle(self): rollback() objects = pickle.loads(data) self.assertEqual([obj.id for obj in objects], [3, 2]) + def test_bulk_delete_clear_query_cache(self): + students1 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:] + self.assertEqual([s.id for s in students1], [2, 3]) + Student.select(lambda s: s.id < 3).delete(bulk=True) + students2 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:] + self.assertEqual([s.id for s in students2], [3]) + rollback() + students1 = Student.select(lambda s: s.id > 1).order_by(Student.id)[:] + self.assertEqual([s.id for s in students1], [2, 3]) if __name__ == '__main__': From ea8020b3e4d036aca456e9d09164fe72d4c39ebe Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 24 Sep 2019 17:50:27 +0300 Subject: [PATCH 480/547] Fix error message when hybrid method is too complex to decompile --- pony/orm/decompiling.py | 13 ++++++---- pony/orm/sqltranslation.py | 7 ++++-- .../test_hybrid_methods_and_properties.py | 24 +++++++++++++++++++ 3 files changed, 37 insertions(+), 7 deletions(-) diff --git a/pony/orm/decompiling.py b/pony/orm/decompiling.py index 080420a3d..068714754 100644 --- a/pony/orm/decompiling.py +++ 
b/pony/orm/decompiling.py @@ -13,6 +13,9 @@ ##ast.And.__repr__ = lambda self: "And(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),) ##ast.Or.__repr__ = lambda self: "Or(%s: %s)" % (getattr(self, 'endpos', '?'), repr(self.nodes),) +class DecompileError(NotImplementedError): + pass + ast_cache = {} def decompile(x): @@ -165,7 +168,7 @@ def decompile(decompiler): decompiler.process_target(pos) method = getattr(decompiler, opname, None) if method is None: - throw(NotImplementedError('Unsupported operation: %s' % opname)) + throw(DecompileError('Unsupported operation: %s' % opname)) decompiler.pos = pos decompiler.next_pos = next_pos x = method(*arg) @@ -296,7 +299,7 @@ def CALL_FUNCTION_VAR_KW(decompiler, argc): def CALL_FUNCTION_EX(decompiler, argc): star2 = None if argc: - if argc != 1: throw(NotImplementedError) + if argc != 1: throw(DecompileError) star2 = decompiler.stack.pop() star = decompiler.stack.pop() return decompiler._call_function([], star, star2) @@ -416,7 +419,7 @@ def process_target(decompiler, pos, partial=False): if hasattr(top, 'endpos'): top2.endpos = top.endpos if decompiler.targets.get(top.endpos) is top: decompiler.targets[top.endpos] = top2 - else: throw(NotImplementedError('Expression is too complex to decompile, try to pass query as string, e.g. select("x for x in Something")')) + else: throw(DecompileError('Expression is too complex to decompile, try to pass query as string, e.g. 
select("x for x in Something")')) top2.endpos = max(top2.endpos, getattr(top, 'endpos', 0)) top = decompiler.stack.pop() decompiler.stack.append(top) @@ -487,7 +490,7 @@ def MAKE_FUNCTION(decompiler, argc): kwonly_defaults = decompiler.stack.pop() if argc & 0x01: defaults = decompiler.stack.pop() - throw(NotImplementedError) + throw(DecompileError) else: if not PY2: qualname = decompiler.stack.pop() @@ -516,7 +519,7 @@ def POP_TOP(decompiler): pass def RETURN_VALUE(decompiler): - if decompiler.next_pos != decompiler.end: throw(NotImplementedError) + if decompiler.next_pos != decompiler.end: throw(DecompileError) expr = decompiler.stack.pop() return simplify(expr) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 65202d8e0..1927b8910 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -14,7 +14,7 @@ from pony import options, utils from pony.utils import localbase, is_ident, throw, reraise, copy_ast, between, concat, coalesce from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError, create_extractors -from pony.orm.decompiling import decompile +from pony.orm.decompiling import decompile, DecompileError from pony.orm.ormtypes import \ numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ normalize, normalize_type, coerce_types, are_comparable_types, \ @@ -1645,7 +1645,10 @@ def __call__(monad, *args, **kwargs): if PY2 and isinstance(func, types.UnboundMethodType): func = func.im_func func_id = id(func) - func_ast, external_names, cells = decompile(func) + try: + func_ast, external_names, cells = decompile(func) + except DecompileError: + throw(TranslationError, '%s(...) 
is too complex to decompile' % ast2src(monad.node)) func_ast, func_extractors = create_extractors( func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py index aa54a52b8..d6568e27e 100644 --- a/pony/orm/tests/test_hybrid_methods_and_properties.py +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -44,6 +44,14 @@ def incorrect_full_name(self): def find_by_full_name(cls, full_name): return cls.select(lambda p: p.full_name_2 == full_name) + def complex_method(self): + result = '' + for i in range(10): + result += str(i) + return result + + def simple_method(self): + return self.complex_method() class FakePerson(object): pass @@ -190,5 +198,21 @@ def test19(self): finally: sep = ' ' + @db_session + @raises_exception(TranslationError, 'p.complex_method(...) is too complex to decompile') + def test_20(self): + q = select(p.complex_method() for p in Person)[:] + + @db_session + @raises_exception(TranslationError, 'p.to_dict(...) is too complex to decompile') + def test_21(self): + q = select(p.to_dict() for p in Person)[:] + + @db_session + @raises_exception(TranslationError, 'self.complex_method(...) 
is too complex to decompile (inside Person.simple_method)') + def test_22(self): + q = select(p.simple_method() for p in Person)[:] + + if __name__ == '__main__': unittest.main() From c938694a9cf048fa71c665d8e4b6cdf7186daeb9 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 1 Oct 2019 14:18:28 +0300 Subject: [PATCH 481/547] Move code around --- pony/orm/sqltranslation.py | 100 ++++++++++++++++++------------------- 1 file changed, 50 insertions(+), 50 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 1927b8910..a9fcbcbf1 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -1633,56 +1633,6 @@ def __pow__(monad, monad2): raise_forgot_parentheses(monad) def __neg__(monad): raise_forgot_parentheses(monad) def abs(monad): raise_forgot_parentheses(monad) -class HybridMethodMonad(MethodMonad): - def __init__(monad, parent, attrname, func): - MethodMonad.__init__(monad, parent, attrname) - monad.func = func - def __call__(monad, *args, **kwargs): - translator = monad.translator - name_mapping = inspect.getcallargs(monad.func, monad.parent, *args, **kwargs) - - func = monad.func - if PY2 and isinstance(func, types.UnboundMethodType): - func = func.im_func - func_id = id(func) - try: - func_ast, external_names, cells = decompile(func) - except DecompileError: - throw(TranslationError, '%s(...) 
is too complex to decompile' % ast2src(monad.node)) - - func_ast, func_extractors = create_extractors( - func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) - - root_translator = translator.root_translator - if func not in root_translator.func_extractors_map: - func_vars, func_vartypes = extract_vars(func_id, translator.filter_num, func_extractors, func.__globals__, {}, cells) - translator.database.provider.normalize_vars(func_vars, func_vartypes) - if func.__closure__: - translator.can_be_cached = False - if func_extractors: - root_translator.func_extractors_map[func] = func_extractors - root_translator.func_vartypes.update(func_vartypes) - root_translator.vartypes.update(func_vartypes) - root_translator.vars.update(func_vars) - - stack = translator.namespace_stack - stack.append(name_mapping) - func_ast = copy_ast(func_ast) - try: - prev_code_key = translator.code_key - translator.code_key = func_id - try: - translator.dispatch(func_ast) - finally: - translator.code_key = prev_code_key - except Exception as e: - if len(e.args) == 1 and isinstance(e.args[0], basestring): - msg = e.args[0] + ' (inside %s.%s)' % (monad.parent.type.__name__, monad.attrname) - e.args = (msg,) - raise - stack.pop() - return func_ast.monad - class EntityMonad(Monad): def __init__(monad, entity): Monad.__init__(monad, SetType(entity)) @@ -2542,6 +2492,56 @@ def negate(monad): def getsql(monad, sqlquery=None): return [ [ 'NOT', monad.operand.getsql()[0] ] ] +class HybridMethodMonad(MethodMonad): + def __init__(monad, parent, attrname, func): + MethodMonad.__init__(monad, parent, attrname) + monad.func = func + def __call__(monad, *args, **kwargs): + translator = monad.translator + name_mapping = inspect.getcallargs(monad.func, monad.parent, *args, **kwargs) + + func = monad.func + if PY2 and isinstance(func, types.UnboundMethodType): + func = func.im_func + func_id = id(func) + try: + func_ast, external_names, cells = decompile(func) + 
except DecompileError: + throw(TranslationError, '%s(...) is too complex to decompile' % ast2src(monad.node)) + + func_ast, func_extractors = create_extractors( + func_id, func_ast, func.__globals__, {}, special_functions, const_functions, outer_names=name_mapping) + + root_translator = translator.root_translator + if func not in root_translator.func_extractors_map: + func_vars, func_vartypes = extract_vars(func_id, translator.filter_num, func_extractors, func.__globals__, {}, cells) + translator.database.provider.normalize_vars(func_vars, func_vartypes) + if func.__closure__: + translator.can_be_cached = False + if func_extractors: + root_translator.func_extractors_map[func] = func_extractors + root_translator.func_vartypes.update(func_vartypes) + root_translator.vartypes.update(func_vartypes) + root_translator.vars.update(func_vars) + + stack = translator.namespace_stack + stack.append(name_mapping) + func_ast = copy_ast(func_ast) + try: + prev_code_key = translator.code_key + translator.code_key = func_id + try: + translator.dispatch(func_ast) + finally: + translator.code_key = prev_code_key + except Exception as e: + if len(e.args) == 1 and isinstance(e.args[0], basestring): + msg = e.args[0] + ' (inside %s.%s)' % (monad.parent.type.__name__, monad.attrname) + e.args = (msg,) + raise + stack.pop() + return func_ast.monad + class ErrorSpecialFuncMonad(Monad): def __init__(monad, func): Monad.__init__(monad, func) From e52c1f652846744b2f6d56c59086776253943d7c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 1 Oct 2019 14:26:28 +0300 Subject: [PATCH 482/547] Move raw_sql function to ormtypes --- pony/orm/core.py | 7 +------ pony/orm/ormtypes.py | 7 ++++++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 484f73ac3..a5ec899c3 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -21,7 +21,7 @@ from pony import options from pony.orm.decompiling import decompile from pony.orm.ormtypes import ( - 
LongStr, LongUnicode, numeric_types, RawSQL, normalize, Json, TrackedValue, QueryType, + LongStr, LongUnicode, numeric_types, raw_sql, RawSQL, normalize, Json, TrackedValue, QueryType, Array, IntArray, StrArray, FloatArray ) from pony.orm.asttranslation import ast2src, create_extractors, TranslationError @@ -5579,11 +5579,6 @@ def desc(expr): return 'desc(%s)' % expr return expr -def raw_sql(sql, result_type=None): - globals = sys._getframe(1).f_globals - locals = sys._getframe(1).f_locals - return RawSQL(sql, globals, locals, result_type) - def extract_vars(code_key, filter_num, extractors, globals, locals, cells=None): if cells: locals = locals.copy() diff --git a/pony/orm/ormtypes.py b/pony/orm/ormtypes.py index ae6cccd8f..080b7f5cf 100644 --- a/pony/orm/ormtypes.py +++ b/pony/orm/ormtypes.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import PY2, items_list, izip, basestring, unicode, buffer, int_types, iteritems -import types, weakref +import sys, types, weakref from decimal import Decimal from datetime import date, time, datetime, timedelta from functools import wraps, WRAPPER_ASSIGNMENTS @@ -97,6 +97,11 @@ def parse_raw_sql(sql): raw_sql_cache[sql] = result return result +def raw_sql(sql, result_type=None): + globals = sys._getframe(1).f_globals + locals = sys._getframe(1).f_locals + return RawSQL(sql, globals, locals, result_type) + class RawSQL(object): def __deepcopy__(self, memo): assert False # should not attempt to deepcopy RawSQL instances, because of locals/globals From cd7fe6ca0eb14bb4e478c82d67fae965b15086c7 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 1 Oct 2019 16:15:58 +0300 Subject: [PATCH 483/547] Hybrid functions --- pony/orm/sqltranslation.py | 43 ++++++++++++------- .../tests/test_declarative_sqltranslator.py | 4 -- .../test_hybrid_methods_and_properties.py | 18 ++++++++ pony/orm/tests/test_query.py | 6 +-- pony/orm/tests/test_raw_sql.py | 3 +- 5 files changed, 50 
insertions(+), 24 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index a9fcbcbf1..78f2dc009 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -16,7 +16,7 @@ from pony.orm.asttranslation import ASTTranslator, ast2src, TranslationError, create_extractors from pony.orm.decompiling import decompile, DecompileError from pony.orm.ormtypes import \ - numeric_types, comparable_types, SetType, FuncType, MethodType, RawSQLType, \ + numeric_types, comparable_types, SetType, FuncType, MethodType, raw_sql, RawSQLType, \ normalize, normalize_type, coerce_types, are_comparable_types, \ Json, QueryType, Array, array_types from pony.orm import core @@ -125,8 +125,11 @@ def dispatch_external(translator, node): monad = monad.call_limit(t.limit, t.offset) elif tt is FuncType: func = t.func - func_monad_class = translator.registered_functions.get(func, ErrorSpecialFuncMonad) - monad = func_monad_class(func) + func_monad_class = translator.registered_functions.get(func) + if func_monad_class is not None: + monad = func_monad_class(func) + else: + monad = HybridFuncMonad(t, func.__name__) elif tt is MethodType: obj, func = t.obj, t.func if isinstance(obj, EntityMeta): @@ -1070,8 +1073,6 @@ def postCallFunc(translator, node): kwargs[arg.name] = arg.expr.monad else: args.append(arg.monad) func_monad = node.node.monad - if isinstance(func_monad, ErrorSpecialFuncMonad): throw(TypeError, - 'Function %r cannot be used this way: %s' % (func_monad.func.__name__, ast2src(node))) return func_monad(*args, **kwargs) def postKeyword(translator, node): pass # this node will be processed by postCallFunc @@ -2492,13 +2493,15 @@ def negate(monad): def getsql(monad, sqlquery=None): return [ [ 'NOT', monad.operand.getsql()[0] ] ] -class HybridMethodMonad(MethodMonad): - def __init__(monad, parent, attrname, func): - MethodMonad.__init__(monad, parent, attrname) - monad.func = func +class HybridFuncMonad(Monad): + def __init__(monad, func_type, 
func_name, *params): + Monad.__init__(monad, func_type) + monad.func = func_type.func + monad.func_name = func_name + monad.params = params def __call__(monad, *args, **kwargs): translator = monad.translator - name_mapping = inspect.getcallargs(monad.func, monad.parent, *args, **kwargs) + name_mapping = inspect.getcallargs(monad.func, *(monad.params + args), **kwargs) func = monad.func if PY2 and isinstance(func, types.UnboundMethodType): @@ -2536,16 +2539,18 @@ def __call__(monad, *args, **kwargs): translator.code_key = prev_code_key except Exception as e: if len(e.args) == 1 and isinstance(e.args[0], basestring): - msg = e.args[0] + ' (inside %s.%s)' % (monad.parent.type.__name__, monad.attrname) + msg = e.args[0] + ' (inside %s)' % (monad.func_name) e.args = (msg,) raise stack.pop() return func_ast.monad -class ErrorSpecialFuncMonad(Monad): - def __init__(monad, func): - Monad.__init__(monad, func) - monad.func = func +class HybridMethodMonad(HybridFuncMonad): + def __init__(monad, parent, attrname, func): + entity = parent.type + assert isinstance(entity, EntityMeta) + func_name = '%s.%s' % (entity.__name__, attrname) + HybridFuncMonad.__init__(monad, FuncType(func), func_name, parent) registered_functions = SQLTranslator.registered_functions = {} @@ -2726,7 +2731,7 @@ class FuncLenMonad(FuncMonad): def call(monad, x): return x.len() -class GetattrMonad(FuncMonad): +class FuncGetattrMonad(FuncMonad): func = getattr def call(monad, obj_monad, name_monad): if isinstance(name_monad, ConstMonad): @@ -2745,6 +2750,12 @@ def call(monad, obj_monad, name_monad): throw(TypeError, 'In `{EXPR}` second argument should be a string. 
Got: %r' % attrname) return obj_monad.getattr(attrname) +class FuncRawSQLMonad(FuncMonad): + func = raw_sql + def call(monad, *args): + throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' + 'because raw SQL fragment will be different for each row') + class FuncCountMonad(FuncMonad): func = itertools.count, utils.count, core.count def call(monad, x=None, distinct=None): diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index 8181646c7..ccddd5779 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -343,10 +343,6 @@ def test_tuple_param_2(self): x = Student[1], None result = set(select(s for s in Student if s not in x)) self.assertEqual(result, {Student[3]}) - @raises_exception(TypeError, "Function 'f' cannot be used this way: f(s)") - def test_unknown_func(self): - def f(x): return x - select(s for s in Student if f(s)) def test_method_monad(self): result = set(select(s for s in Student if s not in Student.select(lambda s: s.scholarship > 0))) self.assertEqual(result, {Student[1]}) diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py index d6568e27e..c1d454dc9 100644 --- a/pony/orm/tests/test_hybrid_methods_and_properties.py +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -70,6 +70,15 @@ class Car(db.Entity): db.generate_mapping(create_tables=True) + +def simple_func(person): + return person.full_name + + +def complex_func(person): + return person.complex_method() + + with db_session: p1 = Person(id=1, first_name='Alexander', last_name='Kozlovsky', favorite_color='white') p2 = Person(id=2, first_name='Alexei', last_name='Malashkevich', favorite_color='green') @@ -213,6 +222,15 @@ def test_21(self): def test_22(self): q = select(p.simple_method() for p in Person)[:] + @db_session + def test_23(self): + q = select(simple_func(p) for p 
in Person)[:] + + @db_session + @raises_exception(TranslationError, 'person.complex_method(...) is too complex to decompile (inside complex_func)') + def test_24(self): + q = select(complex_func(p) for p in Person)[:] + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index f208541ba..94603fe06 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -55,11 +55,11 @@ def test4(self): def test5(self): x = ['A'] select(s for s in Student if s.name == x) - @raises_exception(TypeError, "Function 'f1' cannot be used this way: f1(s.gpa)") def test6(self): def f1(x): - return x + 1 - select(s for s in Student if f1(s.gpa) > 3) + return float(x) + 1 + students = select(s for s in Student if f1(s.gpa) > 4.25)[:] + self.assertEqual({s.id for s in students}, {3}) @raises_exception(NotImplementedError, "m1") def test7(self): class C1(object): diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index 99ab1aa86..4c869be73 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -153,7 +153,8 @@ def test_18(self): self.assertEqual(persons, [Person[1], Person[3], Person[2]]) @db_session - @raises_exception(TypeError, "Function 'raw_sql' cannot be used this way: raw_sql(p.name)") + @raises_exception(TranslationError, "Expression `raw_sql(p.name)` cannot be translated into SQL " + "because raw SQL fragment will be different for each row") def test_19(self): # raw_sql argument cannot depend on iterator variables select(p for p in Person if raw_sql(p.name))[:] From c85e2196e62c67b5d9f784d1b21e65433485842f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 7 Oct 2019 13:43:54 +0300 Subject: [PATCH 484/547] Improved ProgrammingError message in PostgreSQL: "Note: use column type `jsonb` instead of `json`" --- pony/orm/dbapiprovider.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbapiprovider.py 
b/pony/orm/dbapiprovider.py index 50caac12d..9d31dd8ce 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -53,7 +53,13 @@ def wrap_dbapi_exceptions(func, provider, *args, **kwargs): try: return func(provider, *args, **kwargs) finally: provider.local_exceptions.keep_traceback = False except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) - except dbapi_module.ProgrammingError as e: raise ProgrammingError(e) + except dbapi_module.ProgrammingError as e: + if provider.dialect == 'PostgreSQL': + msg = str(e) + if msg.startswith('operator does not exist:') and ' json ' in msg: + msg += ' (Note: use column type `jsonb` instead of `json`)' + raise ProgrammingError(e, msg, *e.args[1:]) + raise ProgrammingError(e) except dbapi_module.InternalError as e: raise InternalError(e) except dbapi_module.IntegrityError as e: raise IntegrityError(e) except dbapi_module.OperationalError as e: From 5e53624b4632a9260d478d06378df620c550e835 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 8 Oct 2019 15:23:07 +0300 Subject: [PATCH 485/547] Remove unused code --- pony/orm/core.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index a5ec899c3..3f3edb28d 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4649,7 +4649,6 @@ def __reduce__(obj): OrmError, '%s object %s has to be stored in DB before it can be pickled' % (obj._status_.capitalize(), safe_repr(obj))) d = {'__class__' : obj.__class__} - adict = obj._adict_ for attr, val in iteritems(obj._vals_): if not attr.is_collection: d[attr.name] = val return unpickle_entity, (d,) From fb7e384a32f6097b9fb2fdb4327966b8529f8bea Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 7 Oct 2019 16:57:44 +0300 Subject: [PATCH 486/547] Bug fixed: incorrect unpickling of objects with Json attributes --- pony/orm/core.py | 42 +++-- pony/orm/tests/test_json.py | 314 +++++++++++++++++++----------------- 2 files changed, 188 insertions(+), 168 
deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 3f3edb28d..70515dcbb 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -4888,6 +4888,18 @@ def _db_set_(obj, avdict, unpickling=False): del avdict[attr] continue + if unpickling: + new_vals = avdict + new_dbvals = {attr: attr.converters[0].val2dbval(val, obj) if not attr.reverse else val + for attr, val in iteritems(avdict)} + else: + new_dbvals = avdict + new_vals = {attr: attr.converters[0].dbval2val(dbval, obj) if not attr.reverse else dbval + for attr, dbval in iteritems(avdict)} + + for attr, new_val in items_list(new_vals): + new_dbval = new_dbvals[attr] + old_dbval = get_dbval(attr, NOT_LOADED) bit = obj._bits_except_volatile_[attr] if rbits & bit: errormsg = 'Please contact PonyORM developers so they can ' \ @@ -4899,28 +4911,26 @@ def _db_set_(obj, avdict, unpickling=False): if attr.reverse: attr.db_update_reverse(obj, old_dbval, new_dbval) obj._dbvals_[attr] = new_dbval - if wbits & bit: del avdict[attr] + if wbits & bit: + del new_vals[attr] + + for attr, new_val in iteritems(new_vals): if attr.is_unique: old_val = get_val(attr) - if old_val != new_dbval: - cache.db_update_simple_index(obj, attr, old_val, new_dbval) + if old_val != new_val: + cache.db_update_simple_index(obj, attr, old_val, new_val) for attrs in obj._composite_keys_: - if any(attr in avdict for attr in attrs): - vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! - prev_vals = tuple(vals) + if any(attr in new_vals for attr in attrs): + key_vals = [ get_val(a) for a in attrs ] # In Python 2 var name leaks into the function scope! 
+ prev_key_vals = tuple(key_vals) for i, attr in enumerate(attrs): - if attr in avdict: vals[i] = avdict[attr] - new_vals = tuple(vals) - if prev_vals != new_vals: - cache.db_update_composite_index(obj, attrs, prev_vals, new_vals) + if attr in new_vals: key_vals[i] = new_vals[attr] + new_key_vals = tuple(key_vals) + if prev_key_vals != new_key_vals: + cache.db_update_composite_index(obj, attrs, prev_key_vals, new_key_vals) - for attr, new_val in iteritems(avdict): - if not attr.reverse: - assert len(attr.converters) == 1, attr - converter = attr.converters[0] - new_val = converter.dbval2val(new_val, obj) - obj._vals_[attr] = new_val + obj._vals_.update(new_vals) def _delete_(obj, undo_funcs=None): status = obj._status_ if status in del_statuses: return diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 23a50a142..2c01d5dc0 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -1,4 +1,4 @@ -from pony.py23compat import basestring +from pony.py23compat import basestring, pickle import unittest @@ -7,22 +7,23 @@ from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict +db = Database('sqlite', ':memory:') -class TestJson(unittest.TestCase): +class Product(db.Entity): + name = Required(str) + info = Optional(Json) + tags = Optional(Json) - def setUp(self): - self.db = Database('sqlite', ':memory:') - class Product(self.db.Entity): - name = Required(str) - info = Optional(Json) - tags = Optional(Json) +db.generate_mapping(create_tables=True) - self.db.generate_mapping(create_tables=True) - self.Product = Product +class TestJson(unittest.TestCase): + def setUp(self): with db_session: - self.Product( + Product.select().delete(bulk=True) + flush() + Product( name='Apple iPad Air 2', info={ 'name': 'Apple iPad Air 2', @@ -66,50 +67,50 @@ class Product(self.db.Entity): def test(self): with db_session: - result = select(p for p in self.Product)[:] + result = select(p for p in Product)[:] 
self.assertEqual(len(result), 1) p = result[0] p.info['os']['version'] = '9' with db_session: - result = select(p for p in self.Product)[:] + result = select(p for p in Product)[:] self.assertEqual(len(result), 1) p = result[0] self.assertEqual(p.info['os']['version'], '9') @db_session def test_query_int(self): - val = get(p.info['display']['resolution'][0] for p in self.Product) + val = get(p.info['display']['resolution'][0] for p in Product) self.assertEqual(val, 2048) @db_session def test_query_float(self): - val = get(p.info['display']['size'] for p in self.Product) + val = get(p.info['display']['size'] for p in Product) self.assertAlmostEqual(val, 9.7) @db_session def test_query_true(self): - val = get(p.info['display']['multi-touch'] for p in self.Product) + val = get(p.info['display']['multi-touch'] for p in Product) self.assertIs(val, True) @db_session def test_query_false(self): - val = get(p.info['discontinued'] for p in self.Product) + val = get(p.info['discontinued'] for p in Product) self.assertIs(val, False) @db_session def test_query_null(self): - val = get(p.info['videoUrl'] for p in self.Product) + val = get(p.info['videoUrl'] for p in Product) self.assertIs(val, None) @db_session def test_query_list(self): - val = get(p.info['colors'] for p in self.Product) + val = get(p.info['colors'] for p in Product) self.assertListEqual(val, ['Gold', 'Silver', 'Space Gray']) self.assertNotIsInstance(val, TrackedValue) @db_session def test_query_dict(self): - val = get(p.info['display'] for p in self.Product) + val = get(p.info['display'] for p in Product) self.assertDictEqual(val, { 'size': 9.7, 'resolution': [2048, 1536], @@ -120,7 +121,7 @@ def test_query_dict(self): @db_session def test_query_json_field(self): - val = get(p.info for p in self.Product) + val = get(p.info for p in Product) self.assertDictEqual(val['display'], { 'size': 9.7, 'resolution': [2048, 1536], @@ -128,13 +129,13 @@ def test_query_json_field(self): 'multi-touch': True }) 
self.assertNotIsInstance(val['display'], TrackedDict) - val = get(p.tags for p in self.Product) + val = get(p.tags for p in Product) self.assertListEqual(val, ['Tablets', 'Apple', 'Retina']) self.assertNotIsInstance(val, TrackedList) @db_session def test_get_object(self): - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertDictEqual(p.info['display'], { 'size': 9.7, 'resolution': [2048, 1536], @@ -151,269 +152,269 @@ def test_get_object(self): def test_set_str(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['os']['version'] = '9' with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertEqual(p.info['os']['version'], '9') def test_set_int(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['display']['resolution'][0] += 1 with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertEqual(p.info['display']['resolution'][0], 2049) def test_set_true(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['discontinued'] = True with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertIs(p.info['discontinued'], True) def test_set_false(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['display']['multi-touch'] = False with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertIs(p.info['display']['multi-touch'], False) def test_set_null(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['display'] = None with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertIs(p.info['display'], None) def test_set_list(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['colors'] = ['Pink', 'Black'] with db_session: - p = get(p 
for p in self.Product) + p = get(p for p in Product) self.assertListEqual(p.info['colors'], ['Pink', 'Black']) def test_list_del(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) del p.info['colors'][1] with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertListEqual(p.info['colors'], ['Gold', 'Space Gray']) def test_list_append(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['colors'].append('White') with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertListEqual(p.info['colors'], ['Gold', 'Silver', 'Space Gray', 'White']) def test_list_set_slice(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['colors'][1:] = ['White'] with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertListEqual(p.info['colors'], ['Gold', 'White']) def test_list_set_item(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['colors'][1] = 'White' with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertListEqual(p.info['colors'], ['Gold', 'White', 'Space Gray']) def test_set_dict(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['display']['resolution'] = {'width': 2048, 'height': 1536} with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertDictEqual(p.info['display']['resolution'], {'width': 2048, 'height': 1536}) def test_dict_del(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) del p.info['os']['version'] with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertDictEqual(p.info['os'], {'type': 'iOS'}) def test_dict_pop(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) 
p.info['os'].pop('version') with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertDictEqual(p.info['os'], {'type': 'iOS'}) def test_dict_update(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['os'].update(version='9') with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) def test_dict_set_item(self): with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) p.info['os']['version'] = '9' with db_session: - p = get(p for p in self.Product) + p = get(p for p in Product) self.assertDictEqual(p.info['os'], {'type': 'iOS', 'version': '9'}) @db_session def test_set_same_value(self): - p = get(p for p in self.Product) + p = get(p for p in Product) p.info = p.info @db_session def test_len(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, 'Oracle does not provide `length` function for JSON arrays'): - val = select(len(p.tags) for p in self.Product).first() + val = select(len(p.tags) for p in Product).first() self.assertEqual(val, 3) - val = select(len(p.info['colors']) for p in self.Product).first() + val = select(len(p.info['colors']) for p in Product).first() self.assertEqual(val, 3) @db_session def test_equal_str(self): - p = get(p for p in self.Product if p.info['name'] == 'Apple iPad Air 2') + p = get(p for p in Product if p.info['name'] == 'Apple iPad Air 2') self.assertTrue(p) @db_session def test_unicode_key(self): - p = get(p for p in self.Product if p.info[u'name'] == 'Apple iPad Air 2') + p = get(p for p in Product if p.info[u'name'] == 'Apple iPad Air 2') self.assertTrue(p) @db_session def test_equal_string_attr(self): - p = get(p for p in self.Product if p.info['name'] == p.name) + p = get(p for p in Product if p.info['name'] == p.name) self.assertTrue(p) @db_session 
def test_equal_param(self): x = 'Apple iPad Air 2' - p = get(p for p in self.Product if p.name == x) + p = get(p for p in Product if p.name == x) self.assertTrue(p) @db_session def test_composite_param(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle doesn't allow parameters in JSON paths"): key = 'models' index = 0 - val = get(p.info[key][index]['name'] for p in self.Product) + val = get(p.info[key][index]['name'] for p in Product) self.assertEqual(val, 'Wi-Fi') @db_session def test_composite_param_in_condition(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle doesn't allow parameters in JSON paths"): key = 'models' index = 0 - p = get(p for p in self.Product if p.info[key][index]['name'] == 'Wi-Fi') + p = get(p for p in Product if p.info[key][index]['name'] == 'Wi-Fi') self.assertIsNotNone(p) @db_session def test_equal_json_1(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: " "p.info['os'] == {'type':'iOS', 'version':'8'}"): - p = get(p for p in self.Product if p.info['os'] == {'type': 'iOS', 'version': '8'}) + p = get(p for p in Product if p.info['os'] == {'type': 'iOS', 'version': '8'}) self.assertTrue(p) @db_session def test_equal_json_2(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: " "p.info['os'] == Json({'type':'iOS', 'version':'8'})"): - p = get(p for p in self.Product if p.info['os'] == Json({'type': 'iOS', 'version': '8'})) + p = get(p for p in Product if p.info['os'] == Json({'type': 'iOS', 'version': '8'})) 
self.assertTrue(p) @db_session def test_ne_json_1(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['os'] != {}"): - p = get(p for p in self.Product if p.info['os'] != {}) + p = get(p for p in Product if p.info['os'] != {}) self.assertTrue(p) - p = get(p for p in self.Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) + p = get(p for p in Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) self.assertFalse(p) @db_session def test_ne_json_2(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['os'] != Json({})"): - p = get(p for p in self.Product if p.info['os'] != Json({})) + p = get(p for p in Product if p.info['os'] != Json({})) self.assertTrue(p) - p = get(p for p in self.Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) + p = get(p for p in Product if p.info['os'] != {'type': 'iOS', 'version': '8'}) self.assertFalse(p) @db_session def test_equal_list_1(self): colors = ['Gold', 'Silver', 'Space Gray'] - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"): - p = get(p for p in self.Product if p.info['colors'] == Json(colors)) + p = get(p for p in Product if p.info['colors'] == Json(colors)) self.assertTrue(p) @db_session @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == ['Gold']") def test_equal_list_2(self): - p = get(p for p in self.Product if p.info['colors'] == ['Gold']) + p = get(p for p in Product if p.info['colors'] == ['Gold']) @db_session def 
test_equal_list_3(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"): - p = get(p for p in self.Product if p.info['colors'] != Json(['Gold'])) + p = get(p for p in Product if p.info['colors'] != Json(['Gold'])) self.assertIsNotNone(p) @db_session def test_equal_list_4(self): colors = ['Gold', 'Silver', 'Space Gray'] - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] == Json(colors)"): - p = get(p for p in self.Product if p.info['colors'] == Json(colors)) + p = get(p for p in Product if p.info['colors'] == Json(colors)) self.assertTrue(p) @db_session @raises_exception(TypeError, "Incomparable types 'Json' and 'list' in expression: p.info['colors'] == []") def test_equal_empty_list_1(self): - p = get(p for p in self.Product if p.info['colors'] == []) + p = get(p for p in Product if p.info['colors'] == []) @db_session def test_equal_empty_list_2(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] == Json([])"): - p = get(p for p in self.Product if p.info['colors'] == Json([])) + p = get(p for p in Product if p.info['colors'] == Json([])) self.assertIsNone(p) @db_session def test_ne_list(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] != Json(['Gold'])"): - p = get(p for p in self.Product if p.info['colors'] != Json(['Gold'])) + p = get(p 
for p in Product if p.info['colors'] != Json(['Gold'])) self.assertTrue(p) @db_session def test_ne_empty_list(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', TranslationError, + with raises_if(self, db.provider.dialect == 'Oracle', TranslationError, "Oracle does not support comparison of json structures: p.info['colors'] != Json([])"): - p = get(p for p in self.Product if p.info['colors'] != Json([])) + p = get(p for p in Product if p.info['colors'] != Json([])) self.assertTrue(p) @db_session def test_dbval2val(self): - p = select(p for p in self.Product)[:][0] - attr = self.Product.info + p = select(p for p in Product)[:][0] + attr = Product.info val = p._vals_[attr] dbval = p._dbvals_[attr] self.assertIsInstance(dbval, basestring) @@ -427,56 +428,56 @@ def test_dbval2val(self): @db_session def test_wildcard_path_1(self): - with raises_if(self, self.db.provider.dialect not in ('Oracle', 'MySQL'), + with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), TranslationError, '...does not support wildcards in JSON path...'): - names = get(p.info['models'][:]['name'] for p in self.Product) + names = get(p.info['models'][:]['name'] for p in Product) self.assertSetEqual(set(names), {'Wi-Fi', 'Wi-Fi + Cellular'}) @db_session def test_wildcard_path_2(self): - with raises_if(self, self.db.provider.dialect not in ('Oracle', 'MySQL'), + with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), TranslationError, '...does not support wildcards in JSON path...'): - values = get(p.info['os'][...] for p in self.Product) + values = get(p.info['os'][...] 
for p in Product) self.assertSetEqual(set(values), {'iOS', '8'}) @db_session def test_wildcard_path_3(self): - with raises_if(self, self.db.provider.dialect not in ('Oracle', 'MySQL'), + with raises_if(self, db.provider.dialect not in ('Oracle', 'MySQL'), TranslationError, '...does not support wildcards in JSON path...'): - names = get(p.info[...][0]['name'] for p in self.Product) + names = get(p.info[...][0]['name'] for p in Product) self.assertSetEqual(set(names), {'Wi-Fi'}) @db_session def test_wildcard_path_4(self): - if self.db.provider.dialect == 'Oracle': + if db.provider.dialect == 'Oracle': raise unittest.SkipTest - with raises_if(self, self.db.provider.dialect != 'MySQL', + with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, '...does not support wildcards in JSON path...'): - values = get(p.info[...][:][...][:] for p in self.Product)[:] + values = get(p.info[...][:][...][:] for p in Product)[:] self.assertSetEqual(set(values), {'16GB', '64GB'}) @db_session def test_wildcard_path_with_params(self): - if self.db.provider.dialect != 'Oracle': + if db.provider.dialect != 'Oracle': exc_msg = '...does not support wildcards in JSON path...' else: exc_msg = "Oracle doesn't allow parameters in JSON paths" - with raises_if(self, self.db.provider.dialect != 'MySQL', TranslationError, exc_msg): + with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, exc_msg): key = 'models' index = 0 - values = get(p.info[key][:]['capacity'][index] for p in self.Product) + values = get(p.info[key][:]['capacity'][index] for p in Product) self.assertListEqual(values, ['16GB', '16GB']) @db_session def test_wildcard_path_with_params_as_string(self): - if self.db.provider.dialect != 'Oracle': + if db.provider.dialect != 'Oracle': exc_msg = '...does not support wildcards in JSON path...' 
else: exc_msg = "Oracle doesn't allow parameters in JSON paths" - with raises_if(self, self.db.provider.dialect != 'MySQL', TranslationError, exc_msg): + with raises_if(self, db.provider.dialect != 'MySQL', TranslationError, exc_msg): key = 'models' index = 0 - values = get("p.info[key][:]['capacity'][index] for p in self.Product") + values = get("p.info[key][:]['capacity'][index] for p in Product") self.assertListEqual(values, ['16GB', '16GB']) @db_session @@ -486,142 +487,141 @@ def test_wildcard_path_in_condition(self): 'SQLite': '...does not support wildcards in JSON path...', 'PostgreSQL': '...does not support wildcards in JSON path...' } - dialect = self.db.provider.dialect + dialect = db.provider.dialect with raises_if(self, dialect in errors, TranslationError, errors.get(dialect)): - p = get(p for p in self.Product if '16GB' in p.info['models'][:]['capacity']) + p = get(p for p in Product if '16GB' in p.info['models'][:]['capacity']) self.assertTrue(p) ##### 'key' in json @db_session def test_in_dict(self): - obj = get(p for p in self.Product if 'resolution' in p.info['display']) + obj = get(p for p in Product if 'resolution' in p.info['display']) self.assertTrue(obj) @db_session def test_not_in_dict(self): - obj = get(p for p in self.Product if 'resolution' not in p.info['display']) + obj = get(p for p in Product if 'resolution' not in p.info['display']) self.assertIs(obj, None) - obj = get(p for p in self.Product if 'xyz' not in p.info['display']) + obj = get(p for p in Product if 'xyz' not in p.info['display']) self.assertTrue(obj) @db_session def test_in_list(self): - obj = get(p for p in self.Product if 'Gold' in p.info['colors']) + obj = get(p for p in Product if 'Gold' in p.info['colors']) self.assertTrue(obj) @db_session def test_not_in_list(self): - obj = get(p for p in self.Product if 'White' not in p.info['colors']) + obj = get(p for p in Product if 'White' not in p.info['colors']) self.assertTrue(obj) - obj = get(p for p in self.Product if 
'Gold' not in p.info['colors']) + obj = get(p for p in Product if 'Gold' not in p.info['colors']) self.assertIs(obj, None) @db_session def test_var_in_json(self): - with raises_if(self, self.db.provider.dialect == 'Oracle', + with raises_if(self, db.provider.dialect == 'Oracle', TypeError, "For `key in JSON` operation Oracle supports literal key values only, " "parameters are not allowed: key in p.info['colors']"): key = 'Gold' - obj = get(p for p in self.Product if key in p.info['colors']) + obj = get(p for p in Product if key in p.info['colors']) self.assertTrue(obj) @db_session def test_select_first(self): # query should not contain ORDER BY - obj = select(p.info for p in self.Product).first() - self.assertNotIn('order by', self.db.last_sql.lower()) + obj = select(p.info for p in Product).first() + self.assertNotIn('order by', db.last_sql.lower()) def test_sql_inject(self): # test quote in json is not causing error with db_session: - p = select(p for p in self.Product).first() + p = select(p for p in Product).first() p.info['display']['size'] = "0' 9.7\"" with db_session: - p = select(p for p in self.Product).first() + p = select(p for p in Product).first() self.assertEqual(p.info['display']['size'], "0' 9.7\"") @db_session def test_int_compare(self): - p = get(p for p in self.Product if p.info['display']['resolution'][0] == 2048) + p = get(p for p in Product if p.info['display']['resolution'][0] == 2048) self.assertTrue(p) - p = get(p for p in self.Product if p.info['display']['resolution'][0] != 2048) + p = get(p for p in Product if p.info['display']['resolution'][0] != 2048) self.assertIsNone(p) - p = get(p for p in self.Product if p.info['display']['resolution'][0] < 2048) + p = get(p for p in Product if p.info['display']['resolution'][0] < 2048) self.assertIs(p, None) - p = get(p for p in self.Product if p.info['display']['resolution'][0] <= 2048) + p = get(p for p in Product if p.info['display']['resolution'][0] <= 2048) self.assertTrue(p) - p = get(p for 
p in self.Product if p.info['display']['resolution'][0] > 2048) + p = get(p for p in Product if p.info['display']['resolution'][0] > 2048) self.assertIs(p, None) - p = get(p for p in self.Product if p.info['display']['resolution'][0] >= 2048) + p = get(p for p in Product if p.info['display']['resolution'][0] >= 2048) self.assertTrue(p) @db_session def test_float_compare(self): - p = get(p for p in self.Product if p.info['display']['size'] > 9.5) + p = get(p for p in Product if p.info['display']['size'] > 9.5) self.assertTrue(p) - p = get(p for p in self.Product if p.info['display']['size'] < 9.8) + p = get(p for p in Product if p.info['display']['size'] < 9.8) self.assertTrue(p) - p = get(p for p in self.Product if p.info['display']['size'] < 9.5) + p = get(p for p in Product if p.info['display']['size'] < 9.5) self.assertIsNone(p) - p = get(p for p in self.Product if p.info['display']['size'] > 9.8) + p = get(p for p in Product if p.info['display']['size'] > 9.8) self.assertIsNone(p) @db_session def test_str_compare(self): - p = get(p for p in self.Product if p.info['ram'] == '8GB') + p = get(p for p in Product if p.info['ram'] == '8GB') self.assertTrue(p) - p = get(p for p in self.Product if p.info['ram'] != '8GB') + p = get(p for p in Product if p.info['ram'] != '8GB') self.assertIsNone(p) - p = get(p for p in self.Product if p.info['ram'] < '9GB') + p = get(p for p in Product if p.info['ram'] < '9GB') self.assertTrue(p) - p = get(p for p in self.Product if p.info['ram'] > '7GB') + p = get(p for p in Product if p.info['ram'] > '7GB') self.assertTrue(p) - p = get(p for p in self.Product if p.info['ram'] > '9GB') + p = get(p for p in Product if p.info['ram'] > '9GB') self.assertIsNone(p) - p = get(p for p in self.Product if p.info['ram'] < '7GB') + p = get(p for p in Product if p.info['ram'] < '7GB') self.assertIsNone(p) @db_session def test_bool_compare(self): - p = get(p for p in self.Product if p.info['display']['multi-touch'] == True) + p = get(p for p in 
Product if p.info['display']['multi-touch'] == True) self.assertTrue(p) - p = get(p for p in self.Product if p.info['display']['multi-touch'] is True) + p = get(p for p in Product if p.info['display']['multi-touch'] is True) self.assertTrue(p) - p = get(p for p in self.Product if p.info['display']['multi-touch'] == False) + p = get(p for p in Product if p.info['display']['multi-touch'] == False) self.assertIsNone(p) - p = get(p for p in self.Product if p.info['display']['multi-touch'] is False) + p = get(p for p in Product if p.info['display']['multi-touch'] is False) self.assertIsNone(p) - p = get(p for p in self.Product if p.info['discontinued'] == False) + p = get(p for p in Product if p.info['discontinued'] == False) self.assertTrue(p) - p = get(p for p in self.Product if p.info['discontinued'] == True) + p = get(p for p in Product if p.info['discontinued'] == True) self.assertIsNone(p) @db_session def test_none_compare(self): - p = get(p for p in self.Product if p.info['videoUrl'] is None) + p = get(p for p in Product if p.info['videoUrl'] is None) self.assertTrue(p) - p = get(p for p in self.Product if p.info['videoUrl'] is not None) + p = get(p for p in Product if p.info['videoUrl'] is not None) self.assertIsNone(p) @db_session def test_none_for_nonexistent_path(self): - p = get(p for p in self.Product if p.info['some_attr'] is None) + p = get(p for p in Product if p.info['some_attr'] is None) self.assertTrue(p) @db_session def test_str_cast(self): - p = get(coalesce(str(p.name), 'empty') for p in self.Product) - self.assertTrue('AS text' in self.db.last_sql) + p = get(coalesce(str(p.name), 'empty') for p in Product) + self.assertTrue('AS text' in db.last_sql) @db_session def test_int_cast(self): - p = get(coalesce(int(p.info['os']['version']), 0) for p in self.Product) - self.assertTrue('as integer' in self.db.last_sql) + p = get(coalesce(int(p.info['os']['version']), 0) for p in Product) + self.assertTrue('as integer' in db.last_sql) def 
test_nonzero(self): - Product = self.Product with db_session: delete(p for p in Product) Product(name='P1', info=dict(id=1, val=True)) @@ -649,7 +649,7 @@ def test_nonzero(self): @db_session def test_optimistic_check(self): - p1 = self.Product.select().first() + p1 = Product.select().first() p1.info['foo'] = 'bar' flush() p1.name = 'name2' @@ -659,5 +659,15 @@ def test_optimistic_check(self): @db_session def test_avg(self): - result = select(avg(p.info['display']['size']) for p in self.Product).first() + result = select(avg(p.info['display']['size']) for p in Product).first() self.assertAlmostEqual(result, 9.7) + + def test_pickle(self): + with db_session: + p1 = Product.select().first() + data = pickle.dumps(p1) + with db_session: + p1 = pickle.loads(data) + p1.name = 'name2' + flush() + rollback() From a0be95ac91acc0c288270c7a90485c67a9fbbba6 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Sat, 12 Oct 2019 00:02:29 +0300 Subject: [PATCH 487/547] Update BACKERS.md --- BACKERS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/BACKERS.md b/BACKERS.md index aa7ee850b..033085f51 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -6,6 +6,7 @@ Pony ORM is Apache 2.0 licensed open source project. 
If you would like to suppor ## Backers +- [Vincere](https://vince.re) - Sergio Aguilar Guerrero - David ROUBLOT - Elijas Dapšauskas From d1c6bb21f6ec232049bc22653e5e63d3a526206d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 14 Oct 2019 16:25:58 +0300 Subject: [PATCH 488/547] Fix bulk delete queries --- pony/orm/sqltranslation.py | 24 ++++++++++++------------ pony/orm/tests/queries.txt | 14 ++++++++++++++ 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 78f2dc009..e2f9196c4 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -362,7 +362,7 @@ def check_name_is_single(): else: check_name_is_single() attr_names = [] - while isinstance(monad, AttrSetMonad) and monad.parent is not None: + while isinstance(monad, (AttrMonad, AttrSetMonad)) and monad.parent is not None: attr_names.append(monad.attr.name) monad = monad.parent attr_names.reverse() @@ -1193,20 +1193,20 @@ def __init__(sqlquery, translator, parent_sqlquery=None, left_join=False): sqlquery.alias_counters = parent_sqlquery.alias_counters.copy() sqlquery.expr_counter = parent_sqlquery.expr_counter sqlquery.used_from_subquery = False - def get_tableref(sqlquery, name_path, from_subquery=False): + def get_tableref(sqlquery, name_path): tableref = sqlquery.tablerefs.get(name_path) - if tableref is not None: - if from_subquery and sqlquery.parent_sqlquery is None: - sqlquery.used_from_subquery = True - return tableref - if sqlquery.parent_sqlquery: - return sqlquery.parent_sqlquery.get_tableref(name_path, from_subquery=True) - return None + parent_sqlquery = sqlquery.parent_sqlquery + if tableref is None and parent_sqlquery: + tableref = parent_sqlquery.get_tableref(name_path) + if tableref is not None: + parent_sqlquery.used_from_subquery = True + return tableref def add_tableref(sqlquery, name_path, parent_tableref, attr): - tablerefs = sqlquery.tablerefs - assert name_path not in tablerefs + assert 
name_path not in sqlquery.tablerefs + if parent_tableref.sqlquery is not sqlquery: + parent_tableref.sqlquery.used_from_subquery = True tableref = JoinedTableRef(sqlquery, name_path, parent_tableref, attr) - tablerefs[name_path] = tableref + sqlquery.tablerefs[name_path] = tableref return tableref def make_alias(sqlquery, name): name = name[:max_alias_length-3].lower() diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index ac371de88..ec1d5fd25 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -893,6 +893,20 @@ WHERE "ID" IN ( ) ) +>>> select(s for s in Student if count(g for g in s.group.dept.groups) > 2).delete(bulk=True) + +DELETE FROM "Student" +WHERE "id" IN ( + SELECT "s"."id" + FROM "Student" "s" + WHERE ( + SELECT COUNT(DISTINCT "g"."number") + FROM "Group" "group", "Group" "g" + WHERE "s"."group" = "group"."number" + AND "group"."dept" = "g"."dept" + ) > 2 + ) + # Test UPPER/LOWER functions: >>> select(s.name.upper() for s in Student) From 37dc81767d0a7097998b5a34bba565624c4d25e8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 15 Oct 2019 18:55:28 +0300 Subject: [PATCH 489/547] #472: Fix warning in Python 3.8 --- pony/orm/core.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 70515dcbb..7de87dbce 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -421,7 +421,7 @@ class DBSessionContextManager(object): 'sql_debug', 'show_values' def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, optimistic=True, retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None): - if retry is not 0: + if retry != 0: if type(retry) is not int: throw(TypeError, "'retry' parameter of db_session must be of integer type. 
Got: %s" % type(retry)) if retry < 0: throw(TypeError, @@ -454,7 +454,7 @@ def __call__(db_session, *args, **kwargs): return db_session._wrap_coroutine_or_generator_function(func) return db_session._wrap_function(func) def __enter__(db_session): - if db_session.retry is not 0: throw(TypeError, + if db_session.retry != 0: throw(TypeError, "@db_session can accept 'retry' parameter only when used as decorator and not as context manager") db_session._enter() def _enter(db_session): From b9527faa9422aa6ec178aafb111f6f58c9a45327 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 15 Oct 2019 18:55:55 +0300 Subject: [PATCH 490/547] #472: Fix test for Python 3.8 --- pony/orm/tests/test_declarative_exceptions.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index 718343b40..b4fa44631 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -153,8 +153,9 @@ def test29(self): @raises_exception(NotImplementedError, "date(s.id, 1, 1)") def test30(self): select(s for s in Student if s.dob < date(s.id, 1, 1)) - @raises_exception(ExprEvalError, "`max()` raises TypeError: max expected 1 arguments, got 0" if not PYPY else - "`max()` raises TypeError: max() expects at least one argument") + @raises_exception(ExprEvalError, "`max()` raises TypeError: max() expects at least one argument" if PYPY else + "`max()` raises TypeError: max expected 1 arguments, got 0" if sys.version_info[:2] < (3, 8) else + "`max()` raises TypeError: max expected 1 argument, got 0") def test31(self): select(s for s in Student if s.id < max()) @raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses") From aa5ed83bab0ad22d2125d7781fab14e24ae50724 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 15 Oct 2019 19:57:29 +0300 Subject: [PATCH 491/547] #472: Fix parser for 
Python 3.8: ignore `namedexpr_test` node --- pony/thirdparty/compiler/transformer.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index d9cb9b5a9..4b13652b9 100644 --- a/pony/thirdparty/compiler/transformer.py +++ b/pony/thirdparty/compiler/transformer.py @@ -142,10 +142,18 @@ def __init__(self): }) self.encoding = None + def print_tree(self, tree, indent=''): + for item in tree: + if isinstance(item, tuple): + self.print_tree(item, indent+' ') + else: + print(indent, symbol.sym_name.get(item, item)) + def transform(self, tree): """Transform an AST into a modified parse tree.""" if not (isinstance(tree, tuple) or isinstance(tree, list)): tree = parser.st2tuple(tree, line_info=1) + # self.print_tree(tree) return self.compile_node(tree) def parsesuite(self, text): @@ -614,7 +622,11 @@ def star_expr(self, *args): def testlist_comp(self, nodelist): # test ( comp_for | (',' test)* [','] ) - assert nodelist[0][0] == symbol.test + PY38 = sys.version_info >= (3, 8) + if PY38 and nodelist[0][0] == symbol.namedexpr_test: + nodelist = (nodelist[0][1],) + nodelist[1:] + if nodelist[0][0] != symbol.test: + assert False, symbol.sym_name.get(nodelist[0][0], nodelist[0][0]) if len(nodelist) == 2 and nodelist[1][0] == symbol.comp_for: test = self.com_node(nodelist[0]) return self.com_generator_expression(test, nodelist[1]) From 1d60b3779d3077119a36e52303cc77cc95f7fd05 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 15 Oct 2019 20:00:05 +0300 Subject: [PATCH 492/547] Fixes #472: Specify Python 3.8 support in setup.py --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index e1689e30c..0c748118c 100644 --- a/setup.py +++ b/setup.py @@ -72,6 +72,7 @@ def test_suite(): 'Programming Language :: Python :: 3.5', 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', + 'Programming 
Language :: Python :: 3.8', 'Programming Language :: Python :: Implementation :: PyPy', 'Topic :: Software Development :: Libraries', 'Topic :: Database' @@ -105,8 +106,8 @@ def test_suite(): if __name__ == "__main__": pv = sys.version_info[:2] - if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7)): - s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.7." \ + if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8)): + s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.8." \ " You have version %s" print(s % (name, version, sys.version.split(' ', 1)[0])) sys.exit(1) From ecc55b5aa940165dbee432e76d5ed76ad0c2db68 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Thu, 19 Sep 2019 19:42:08 +0300 Subject: [PATCH 493/547] Update BACKERS.md --- BACKERS.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/BACKERS.md b/BACKERS.md index 6b4a6ba51..aa7ee850b 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -10,4 +10,6 @@ Pony ORM is Apache 2.0 licensed open source project. If you would like to suppor - David ROUBLOT - Elijas Dapšauskas - Dan Swain - +- Christian Macht +- Johnathan Nader +- Andrei Rachalouski From 62b03f7e577e8f551dfa29cb9422ffe3488fabf3 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Sat, 12 Oct 2019 00:02:29 +0300 Subject: [PATCH 494/547] Update BACKERS.md --- BACKERS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/BACKERS.md b/BACKERS.md index aa7ee850b..033085f51 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -6,6 +6,7 @@ Pony ORM is Apache 2.0 licensed open source project. 
If you would like to suppor ## Backers +- [Vincere](https://vince.re) - Sergio Aguilar Guerrero - David ROUBLOT - Elijas Dapšauskas From 4efb56f00a8d80cb4169c279776442397b5c9aed Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 16 Oct 2019 17:59:57 +0300 Subject: [PATCH 495/547] Fix Python 3.8 support: add namedexpr_test method to transformer --- pony/thirdparty/compiler/transformer.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/pony/thirdparty/compiler/transformer.py b/pony/thirdparty/compiler/transformer.py index 4b13652b9..18f5c0cfe 100644 --- a/pony/thirdparty/compiler/transformer.py +++ b/pony/thirdparty/compiler/transformer.py @@ -41,6 +41,7 @@ if not hasattr(symbol, 'comp_if'): symbol.comp_if = symbol.gen_if atom_expr = getattr(symbol, 'atom_expr', None) +namedexpr_test = getattr(symbol, 'namedexpr_test', None) class WalkerError(Exception): pass @@ -121,6 +122,9 @@ def atom_expr(self, nodelist): node = self.com_apply_trailer(node, elt) return node + def namedexpr_test(self, nodelist): + return self.test(nodelist[0][1:]) + def __init__(self): self._dispatch = {} for value, name in symbol.sym_name.items(): @@ -623,10 +627,9 @@ def star_expr(self, *args): def testlist_comp(self, nodelist): # test ( comp_for | (',' test)* [','] ) PY38 = sys.version_info >= (3, 8) - if PY38 and nodelist[0][0] == symbol.namedexpr_test: - nodelist = (nodelist[0][1],) + nodelist[1:] - if nodelist[0][0] != symbol.test: - assert False, symbol.sym_name.get(nodelist[0][0], nodelist[0][0]) + code = nodelist[0][0] + if code not in (symbol.test, namedexpr_test): + assert False, symbol.sym_name.get(code, code) if len(nodelist) == 2 and nodelist[1][0] == symbol.comp_for: test = self.com_node(nodelist[0]) return self.com_generator_expression(test, nodelist[1]) From bbf8de6816a8511069fdc31bc856ab611ae3f475 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 16 Oct 2019 18:06:22 +0300 Subject: [PATCH 496/547] Fix bulk delete queries --- 
pony/orm/sqltranslation.py | 16 ++++++++++------ pony/orm/tests/queries.txt | 10 ++++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e2f9196c4..8e9c45161 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -771,17 +771,21 @@ def construct_delete_sql_ast(translator): expr_monad = translator.tree.expr.monad if not isinstance(entity, EntityMeta): throw(TranslationError, 'Delete query should be applied to a single entity. Got: %s' % ast2src(translator.tree.expr)) - if translator.groupby_monads: throw(TranslationError, - 'Delete query cannot contains GROUP BY section or aggregate functions') - assert not translator.having_conditions + force_in = False + if translator.groupby_monads: + force_in = True + else: + assert not translator.having_conditions tableref = expr_monad.tableref from_ast = translator.sqlquery.from_ast - assert from_ast[0] == 'FROM' - if len(from_ast) == 2 and not translator.sqlquery.used_from_subquery: + if from_ast[0] != 'FROM': + force_in = True + + if not force_in and len(from_ast) == 2 and not translator.sqlquery.used_from_subquery: sql_ast = [ 'DELETE', None, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) - elif translator.dialect == 'MySQL': + elif not force_in and translator.dialect == 'MySQL': sql_ast = [ 'DELETE', tableref.alias, from_ast ] if translator.conditions: sql_ast.append([ 'WHERE' ] + translator.conditions) diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index ec1d5fd25..1992a7a4d 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -907,6 +907,16 @@ WHERE "id" IN ( ) > 2 ) +>>> Student.select(lambda s: count(s.group.students) == 2).delete(bulk=True) + +DELETE FROM "Student" +WHERE "id" IN ( + SELECT "s"."id" + FROM "Student" "s" + LEFT JOIN "Student" "student" + ON "s"."group" = "student"."group" + ) + # Test UPPER/LOWER functions: >>> 
select(s.name.upper() for s in Student) From 397e6ac7b31d08ffe697b4f51ea5303fa3e73f8d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Oct 2019 16:28:41 +0300 Subject: [PATCH 497/547] Fixes #468: Tuple-value comparisons generate incorrect queries --- pony/orm/sqltranslation.py | 7 +++-- pony/orm/tests/queries.txt | 57 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+), 3 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 8e9c45161..faabda085 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2451,9 +2451,10 @@ def getsql(monad, sqlquery=None): if monad.translator.row_value_syntax: return [ [ cmp_ops[op], [ 'ROW' ] + left_sql, [ 'ROW' ] + right_sql ] ] clauses = [] - for i in xrange(1, size): - clauses.append(sqland([ [ monad.EQ, left_sql[j], right_sql[j] ] for j in xrange(1, i) ] - + [ [ cmp_ops[op[0] if i < size - 1 else op], left_sql[i], right_sql[i] ] ])) + for i in xrange(size): + clause = [ [ monad.EQ, left_sql[j], right_sql[j] ] for j in range(i) ] + clause.append([ cmp_ops[op], left_sql[i], right_sql[i] ]) + clauses.append(sqland(clause)) return [ sqlor(clauses) ] if op == '==': return [ sqland([ [ monad.EQ, a, b ] for a, b in izip(left_sql, right_sql) ]) ] diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 1992a7a4d..877a05ca9 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -1014,3 +1014,60 @@ SELECT t.* FROM ( ) t ) t WHERE "row-num" > 3 +# Test row comparison: + +>>> select((s1.id, s2.id) for s1 in Student for s2 in Student if (s1.name, s1.gpa, s1.tel) < (s2.name, s2.gpa, s2.tel)) + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "Student" "s1", "Student" "s2" +WHERE ("s1"."name" < "s2"."name" OR "s1"."name" = "s2"."name" AND "s1"."gpa" < "s2"."gpa" OR "s1"."name" = "s2"."name" AND "s1"."gpa" = "s2"."gpa" AND "s1"."tel" < "s2"."tel") + +PostgreSQL: + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "student" "s1", 
"student" "s2" +WHERE ("s1"."name", "s1"."gpa", "s1"."tel") < ("s2"."name", "s2"."gpa", "s2"."tel") + +MySQL: + +SELECT DISTINCT `s1`.`id`, `s2`.`id` +FROM `student` `s1`, `student` `s2` +WHERE (`s1`.`name`, `s1`.`gpa`, `s1`.`tel`) < (`s2`.`name`, `s2`.`gpa`, `s2`.`tel`) + +Oracle: + +SELECT DISTINCT "s1"."ID", "s2"."ID" +FROM "STUDENT" "s1", "STUDENT" "s2" +WHERE ("s1"."NAME", "s1"."GPA", "s1"."TEL") < ("s2"."NAME", "s2"."GPA", "s2"."TEL") + +>>> select((s1.id, s2.id) for s1 in Student for s2 in Student if (s1.name, s1.gpa, s1.tel) == (s2.name, s2.gpa, s2.tel)) + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "Student" "s1", "Student" "s2" +WHERE "s1"."name" = "s2"."name" + AND "s1"."gpa" = "s2"."gpa" + AND "s1"."tel" = "s2"."tel" + +PostgreSQL: + +SELECT DISTINCT "s1"."id", "s2"."id" +FROM "student" "s1", "student" "s2" +WHERE "s1"."name" = "s2"."name" + AND "s1"."gpa" = "s2"."gpa" + AND "s1"."tel" = "s2"."tel" + +MySQL: + +SELECT DISTINCT `s1`.`id`, `s2`.`id` +FROM `student` `s1`, `student` `s2` +WHERE `s1`.`name` = `s2`.`name` + AND `s1`.`gpa` = `s2`.`gpa` + AND `s1`.`tel` = `s2`.`tel` + +Oracle: + +SELECT DISTINCT "s1"."ID", "s2"."ID" +FROM "STUDENT" "s1", "STUDENT" "s2" +WHERE "s1"."NAME" = "s2"."NAME" + AND "s1"."GPA" = "s2"."GPA" + AND "s1"."TEL" = "s2"."TEL" From 0fa39eb1412427151fc32539385ec6bf7b807472 Mon Sep 17 00:00:00 2001 From: Javier Caballero Date: Tue, 15 Oct 2019 09:38:59 +0200 Subject: [PATCH 498/547] Fix #470 'imp' module PendingDeprecationWarning The package 'imp' will be deprecated in favor of importlib. This change allows to use os 'importlib' when python version is greater than 2 or use 'imp' otherwise. 
--- pony/thirdparty/compiler/pycodegen.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pony/thirdparty/compiler/pycodegen.py b/pony/thirdparty/compiler/pycodegen.py index 0181b4fa8..fd24132c6 100644 --- a/pony/thirdparty/compiler/pycodegen.py +++ b/pony/thirdparty/compiler/pycodegen.py @@ -1,7 +1,6 @@ from __future__ import absolute_import, print_function from pony.py23compat import izip -import imp import os import marshal import struct @@ -123,7 +122,12 @@ def dump(self, f): f.write(self.getPycHeader()) marshal.dump(self.code, f) - MAGIC = imp.get_magic() + if VERSION < 3: + import imp + MAGIC = imp.get_magic() + else: + import importlib.util + MAGIC = importlib.util.MAGIC_NUMBER def getPycHeader(self): # compile.c uses marshal to write a long directly, with From 45cbff73bfeb9d8d84804134f7a33a58736297ba Mon Sep 17 00:00:00 2001 From: Jinxu <15059493+imfht@users.noreply.github.com> Date: Mon, 23 Sep 2019 16:45:59 +0800 Subject: [PATCH 499/547] Fixes #465, fixes #466, closes #467: Should reconnect to MySQL on OperationalError 2013 'Lost connection to MySQL server during query' --- pony/orm/dbproviders/mysql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 23ea895aa..3170daa60 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -230,7 +230,7 @@ def inspect_connection(provider, connection): cursor.execute('set session group_concat_max_len = 4294967295') def should_reconnect(provider, exc): - return isinstance(exc, mysql_module.OperationalError) and exc.args[0] == 2006 + return isinstance(exc, mysql_module.OperationalError) and exc.args[0] in (2006, 2013) def get_pool(provider, *args, **kwargs): if 'conv' not in kwargs: From 368adf8767b986a5baf4d2be26e320cea82441c6 Mon Sep 17 00:00:00 2001 From: Vladimir Date: Fri, 27 Sep 2019 13:40:58 -0300 Subject: [PATCH 500/547] Adding tests for Attribute Options and Entity Instances 
--- pony/orm/tests/test_attribute_options.py | 105 +++++++++++++++++++++++ pony/orm/tests/test_entity_instances.py | 103 ++++++++++++++++++++++ 2 files changed, 208 insertions(+) create mode 100644 pony/orm/tests/test_attribute_options.py create mode 100644 pony/orm/tests/test_entity_instances.py diff --git a/pony/orm/tests/test_attribute_options.py b/pony/orm/tests/test_attribute_options.py new file mode 100644 index 000000000..017637c84 --- /dev/null +++ b/pony/orm/tests/test_attribute_options.py @@ -0,0 +1,105 @@ +import unittest +from decimal import Decimal +from datetime import datetime, time +from random import randint + +from pony import orm +from pony.orm.core import * +from pony.orm.tests.testutils import raises_exception + +db = Database('sqlite', ':memory:') + +class Person(db.Entity): + name = orm.Required(str, 40) + lastName = orm.Required(str, max_len=40, unique=True) + age = orm.Optional(int, max=60, min=10) + nickName = orm.Optional(str, autostrip=False) + middleName = orm.Optional(str, nullable=True) + rate = orm.Optional(Decimal, precision=11) + salaryRate = orm.Optional(Decimal, precision=13, scale=8) + timeStmp = orm.Optional(datetime, precision=6) + gpa = orm.Optional(float, py_check=lambda val: val >= 0 and val <= 5) + vehicle = orm.Optional(str, column='car') + +db.generate_mapping(create_tables=True) + +with orm.db_session: + p1 = Person(name='Andrew', lastName='Bodroue', age=40, rate=0.980000000001, salaryRate=0.98000001) + p2 = Person(name='Vladimir', lastName='Andrew ', nickName='vlad ') + p3 = Person(name='Nick', lastName='Craig', middleName=None, timeStmp='2010-12-10 14:12:09.019473', vehicle='dodge') + +class TestAttributeOptions(unittest.TestCase): + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_optionalStringEmpty(self): + queryResult = select(p.id for p in Person if p.nickName==None).first() + self.assertIsNone(queryResult) + + def 
test_optionalStringNone(self): + queryResult = select(p.id for p in Person if p.middleName==None).first() + self.assertIsNotNone(queryResult) + + def test_stringAutoStrip(self): + self.assertEqual(p2.lastName, 'Andrew') + + def test_stringAutoStripFalse(self): + self.assertEqual(p2.nickName, 'vlad ') + + def test_intNone(self): + queryResult = select(p.id for p in Person if p.age==None).first() + self.assertIsNotNone(queryResult) + + def test_columnName(self): + self.assertEqual(getattr(Person.vehicle, 'column'), 'car') + + def test_decimalPrecisionTwo(self): + queryResult = select(p.rate for p in Person if p.age==40).first() + self.assertAlmostEqual(float(queryResult), 0.98, 12) + + def test_decimalPrecisionEight(self): + queryResult = select(p.salaryRate for p in Person if p.age==40).first() + self.assertAlmostEqual(float(queryResult), 0.98000001, 8) + + def test_fractionalSeconds(self): + queryResult = select(p.timeStmp for p in Person if p.name=='Nick').first() + self.assertEqual(queryResult.microsecond, 19473) + + def test_intMax(self): + p4 = Person(name='Denis', lastName='Blanc', age=60) + + def test_intMin(self): + p4 = Person(name='Denis', lastName='Blanc', age=10) + + @raises_exception(ValueError, "Value 61 of attr Person.age is greater than the maximum allowed value 60") + def test_intMaxException(self): + p4 = Person(name='Denis', lastName='Blanc', age=61) + + @raises_exception(ValueError, "Value 9 of attr Person.age is less than the minimum allowed value 10") + def test_intMinException(self): + p4 = Person(name='Denis', lastName='Blanc', age=9) + + def test_py_check(self): + p4 = Person(name='Denis', lastName='Blanc', gpa=5) + p5 = Person(name='Mario', lastName='Gon', gpa=1) + flush() + + @raises_exception(ValueError, "Check for attribute Person.gpa failed. Value: 6.0") + def test_py_checkMoreException(self): + p6 = Person(name='Daniel', lastName='Craig', gpa=6) + + @raises_exception(ValueError, "Check for attribute Person.gpa failed. 
Value: -1.0") + def test_py_checkLessException(self): + p6 = Person(name='Daniel', lastName='Craig', gpa=-1) + + @raises_exception(TransactionIntegrityError, 'Object Person[new:8] cannot be stored in the database.' + ' IntegrityError: UNIQUE constraint failed: Person.lastName') + def test_unique(self): + p6 = Person(name='Boris', lastName='Bodroue') + flush() \ No newline at end of file diff --git a/pony/orm/tests/test_entity_instances.py b/pony/orm/tests/test_entity_instances.py new file mode 100644 index 000000000..bf1e302a0 --- /dev/null +++ b/pony/orm/tests/test_entity_instances.py @@ -0,0 +1,103 @@ +import unittest + +from pony import orm +from pony.orm.core import * +from pony.orm.tests.testutils import raises_exception + +db = Database('sqlite', ':memory:') + +class Person(db.Entity): + id = orm.PrimaryKey(int, auto=True) + name = orm.Required(str, 40) + lastName = orm.Required(str, max_len=40, unique=True) + age = orm.Optional(int) + groupName = orm.Optional('Group') + chiefOfGroup = orm.Optional('Group') + +class Group(db.Entity): + name = orm.Required(str) + persons = orm.Set(Person) + chief = orm.Optional(Person, reverse='chiefOfGroup') + +db.generate_mapping(create_tables=True) + +class TestEntityInstances(unittest.TestCase): + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_create_instance(self): + with orm.db_session: + Person(id=1, name='Philip', lastName='Croissan') + Person(id=2, name='Philip', lastName='Parlee', age=40) + Person(id=3, name='Philip', lastName='Illinois', age=50) + commit() + + def test_getObjectByPK(self): + self.assertEqual(Person[1].lastName, "Croissan") + + @raises_exception(ObjectNotFound , "Person[666]") + def test_getObjectByPKexception(self): + p = Person[666] + + def test_getObjectByGet(self): + p = Person.get(age=40) + self.assertEqual(p.lastName, "Parlee") + + def test_getObjectByGetNone(self): + self.assertIsNone(Person.get(age=41)) + + 
@raises_exception(MultipleObjectsFoundError , 'Multiple objects were found.' + ' Use Person.select(...) to retrieve them') + def test_getObjectByGetException(self): + p = Person.get(name="Philip") + + def test_updateObject(self): + with db_session: + Person[2].age=42 + self.assertEqual(Person[2].age, 42) + commit() + + @raises_exception(ObjectNotFound, 'Person[2]') + def test_deleteObject(self): + with db_session: + Person[2].delete() + p = Person[2] + + def test_bulkDelete(self): + with orm.db_session: + Person(id=4, name='Klaus', lastName='Mem', age=12) + Person(id=5, name='Abraham', lastName='Wrangler', age=13) + Person(id=6, name='Kira', lastName='Phito', age=20) + delete(p for p in Person if p.age <= 20) + self.assertEqual(select(p for p in Person if p.age <= 20).count(), 0) + + def test_bulkDeleteV2(self): + with orm.db_session: + Person(id=4, name='Klaus', lastName='Mem', age=12) + Person(id=5, name='Abraham', lastName='Wrangler', age=13) + Person(id=6, name='Kira', lastName='Phito', age=20) + Person.select(lambda p: p.id >= 4).delete(bulk=True) + self.assertEqual(select(p for p in Person if p.id >= 4).count(), 0) + + @raises_exception(UnresolvableCyclicDependency, 'Cannot save cyclic chain: Person -> Group') + def test_saveChainsException(self): + with orm.db_session: + claire = Person(name='Claire', lastName='Forlani') + annabel = Person(name='Annabel', lastName='Fiji') + Group(name='Aspen', persons=[claire, annabel], chief=claire) + print('group1=', Group[1]) + + def test_saveChainsWithFlush(self): + with orm.db_session: + claire = Person(name='Claire', lastName='Forlani') + annabel = Person(name='Annabel', lastName='Fiji') + flush() + Group(name='Aspen', persons=[claire, annabel], chief=claire) + self.assertEqual(Group[1].name, 'Aspen') + self.assertEqual(Group[1].chief.lastName, 'Forlani') \ No newline at end of file From a592b96611953d012140c25d34d1715482519979 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Oct 2019 23:51:04 +0300 
Subject: [PATCH 501/547] raises_exception should correctly handle ... in the middle of the test message --- pony/orm/tests/testutils.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/pony/orm/tests/testutils.py b/pony/orm/tests/testutils.py index 13f507fff..5b28c644d 100644 --- a/pony/orm/tests/testutils.py +++ b/pony/orm/tests/testutils.py @@ -1,7 +1,7 @@ from __future__ import absolute_import, print_function, division from pony.py23compat import basestring -from functools import wraps +import re from contextlib import contextmanager from pony.orm.core import Database @@ -10,16 +10,18 @@ def test_exception_msg(test_case, exc_msg, test_msg=None): if test_msg is None: return error_template = "incorrect exception message. expected '%s', got '%s'" + error_msg = error_template % (test_msg, exc_msg) assert test_msg not in ('...', '....', '.....', '......') - if test_msg.startswith('...'): - if test_msg.endswith('...'): - test_case.assertIn(test_msg[3:-3], exc_msg, error_template % (test_msg, exc_msg)) - else: - test_case.assertTrue(exc_msg.endswith(test_msg[3:]), error_template % (test_msg, exc_msg)) - elif test_msg.endswith('...'): - test_case.assertTrue(exc_msg.startswith(test_msg[:-3]), error_template % (test_msg, exc_msg)) + if '...' 
not in test_msg: + test_case.assertEqual(test_msg, exc_msg, error_msg) else: - test_case.assertEqual(exc_msg, test_msg, error_template % (test_msg, exc_msg)) + pattern = ''.join( + '[%s]' % char for char in test_msg.replace('\\', '\\\\') + .replace('[', '\\[') + ).replace('[.][.][.]', '.*') + regex = re.compile(pattern) + if not regex.match(exc_msg): + test_case.fail(error_template % (test_msg, exc_msg)) def raises_exception(exc_class, test_msg=None): def decorator(func): From ca9a62c74d9d5707dc1a9deef0b1a29423262252 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 21 Oct 2019 23:51:22 +0300 Subject: [PATCH 502/547] Test fixed --- pony/orm/tests/test_attribute_options.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/tests/test_attribute_options.py b/pony/orm/tests/test_attribute_options.py index 017637c84..cf8d79591 100644 --- a/pony/orm/tests/test_attribute_options.py +++ b/pony/orm/tests/test_attribute_options.py @@ -98,7 +98,7 @@ def test_py_checkMoreException(self): def test_py_checkLessException(self): p6 = Person(name='Daniel', lastName='Craig', gpa=-1) - @raises_exception(TransactionIntegrityError, 'Object Person[new:8] cannot be stored in the database.' + @raises_exception(TransactionIntegrityError, 'Object Person[new:...] cannot be stored in the database.' 
' IntegrityError: UNIQUE constraint failed: Person.lastName') def test_unique(self): p6 = Person(name='Boris', lastName='Bodroue') From 2345f313eb1b78f050f2c5c0add711006b166fbc Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 22 Oct 2019 14:42:58 +0300 Subject: [PATCH 503/547] Revert change from 52c88b6e "Forcing of n+1 query optimization during cascade delete", add tests on cascade delete --- pony/orm/core.py | 3 +- pony/orm/tests/test_cascade_delete.py | 66 +++++++++++++++++++++++++++ 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 pony/orm/tests/test_cascade_delete.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 7de87dbce..6b6628c93 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -2914,8 +2914,7 @@ def load(attr, obj, items=None): counter = cache.collection_statistics.setdefault(attr, 0) nplus1_threshold = attr.nplus1_threshold - prefetching = not attr.lazy and nplus1_threshold is not None \ - and (counter >= nplus1_threshold or cache.noflush_counter) + prefetching = not attr.lazy and nplus1_threshold is not None and counter >= nplus1_threshold objects = [ obj ] setdata_list = [ setdata ] diff --git a/pony/orm/tests/test_cascade_delete.py b/pony/orm/tests/test_cascade_delete.py new file mode 100644 index 000000000..e5974fc45 --- /dev/null +++ b/pony/orm/tests/test_cascade_delete.py @@ -0,0 +1,66 @@ +import unittest + +from pony.orm import * + +db = Database('sqlite', ':memory:') + +class X(db.Entity): + id = PrimaryKey(int) + parent = Optional('X', reverse='children') + children = Set('X', reverse='parent', cascade_delete=True) + +class Y(db.Entity): + parent = Optional('Y', reverse='children') + children = Set('Y', reverse='parent', cascade_delete=True, lazy=True) + + +db.generate_mapping(create_tables=True) + +with db_session: + x1 = X(id=1) + x2 = X(id=2, parent=x1) + x3 = X(id=3, parent=x1) + x4 = X(id=4, parent=x3) + x5 = X(id=5, parent=x3) + x6 = X(id=6, parent=x5) + x7 = X(id=7, parent=x3) + x8 = X(id=8, 
parent=x7) + x9 = X(id=9, parent=x7) + x10 = X(id=10) + x11 = X(id=11, parent=x10) + x12 = X(id=12, parent=x10) + + y1 = Y(id=1) + y2 = Y(id=2, parent=y1) + y3 = Y(id=3, parent=y1) + y4 = Y(id=4, parent=y3) + y5 = Y(id=5, parent=y3) + y6 = Y(id=6, parent=y5) + y7 = Y(id=7, parent=y3) + y8 = Y(id=8, parent=y7) + y9 = Y(id=9, parent=y7) + y10 = Y(id=10) + y11 = Y(id=11, parent=y10) + y12 = Y(id=12, parent=y10) + +class TestCascade(unittest.TestCase): + + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_1(self): + db.merge_local_stats() + X[1].delete() + stats = db.local_stats[None] + self.assertEqual(5, stats.db_count) + + def test_2(self): + db.merge_local_stats() + Y[1].delete() + stats = db.local_stats[None] + self.assertEqual(10, stats.db_count) From 0ffe60f2c8104e253687fcf3150a14f5ef8a543e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Oct 2019 15:58:21 +0300 Subject: [PATCH 504/547] Fix #438: support date-date, datetime-datetime, date+timedelta, date-timedelta, datetime+timedelta, datetime-timedelta in queries --- pony/orm/dbproviders/mysql.py | 36 ++++--- pony/orm/dbproviders/oracle.py | 15 +-- pony/orm/dbproviders/postgres.py | 18 ++-- pony/orm/dbproviders/sqlite.py | 36 +++++-- pony/orm/sqlbuilding.py | 37 ++++--- pony/orm/sqltranslation.py | 41 ++++++-- pony/orm/tests/queries.txt | 112 ++++++++++++++++++++ pony/orm/tests/sql_tests.py | 2 + pony/orm/tests/test_datetime.py | 134 ++++++++++++++++++++++++ pony/orm/tests/test_declarative_date.py | 61 ----------- 10 files changed, 372 insertions(+), 120 deletions(-) create mode 100644 pony/orm/tests/test_datetime.py delete mode 100644 pony/orm/tests/test_declarative_date.py diff --git a/pony/orm/dbproviders/mysql.py b/pony/orm/dbproviders/mysql.py index 3170daa60..aacf6d44c 100644 --- a/pony/orm/dbproviders/mysql.py +++ b/pony/orm/dbproviders/mysql.py @@ -47,8 +47,20 @@ class MySQLTranslator(SQLTranslator): dialect = 
'MySQL' json_path_wildcard_syntax = True +class MySQLValue(Value): + __slots__ = [] + def __unicode__(self): + value = self.value + if isinstance(value, timedelta): + if value.microseconds: + return "INTERVAL '%s' HOUR_MICROSECOND" % timedelta2str(value) + return "INTERVAL '%s' HOUR_SECOND" % timedelta2str(value) + return Value.__unicode__(self) + if not PY2: __str__ = __unicode__ + class MySQLBuilder(SQLBuilder): dialect = 'MySQL' + value_class = MySQLValue def CONCAT(builder, *args): return 'concat(', join(', ', imap(builder, args)), ')' def TRIM(builder, expr, chars=None): @@ -79,21 +91,21 @@ def MINUTE(builder, expr): def SECOND(builder, expr): return 'second(', builder(expr), ')' def DATE_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return 'DATE_ADD(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" - return 'ADDTIME(', builder(expr), ', ', builder(delta), ')' + if delta[0] == 'VALUE' and isinstance(delta[1], time): + return 'ADDTIME(', builder(expr), ', ', builder(delta), ')' + return 'ADDDATE(', builder(expr), ', ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return 'DATE_SUB(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" - return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' + if delta[0] == 'VALUE' and isinstance(delta[1], time): + return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' + return 'SUBDATE(', builder(expr), ', ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return 'TIMEDIFF(', builder(expr1), ', ', builder(expr2), ')' def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return 'DATE_ADD(', builder(expr), ", INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" - return 'ADDTIME(', builder(expr), ', ', builder(delta), ')' + return builder.DATE_ADD(expr, delta) def DATETIME_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return 'DATE_SUB(', builder(expr), ", 
INTERVAL '", timedelta2str(delta), "' HOUR_SECOND)" - return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' + return builder.DATE_SUB(expr, delta) + def DATETIME_DIFF(builder, expr1, expr2): + return 'TIMEDIFF(', builder(expr1), ', ', builder(expr2), ')' def JSON_QUERY(builder, expr, path): path_sql, has_params, has_wildcards = builder.build_json_path(path) return 'json_extract(', builder(expr), ', ', path_sql, ')' diff --git a/pony/orm/dbproviders/oracle.py b/pony/orm/dbproviders/oracle.py index d4f338be0..db38e01fb 100644 --- a/pony/orm/dbproviders/oracle.py +++ b/pony/orm/dbproviders/oracle.py @@ -15,10 +15,9 @@ from pony.orm.core import log_orm, log_sql, DatabaseError, TranslationError from pony.orm.dbschema import DBSchema, DBObject, Table, Column from pony.orm.ormtypes import Json -from pony.orm.sqlbuilding import SQLBuilder, Value +from pony.orm.sqlbuilding import SQLBuilder from pony.orm.dbapiprovider import DBAPIProvider, wrap_dbapi_exceptions, get_version_tuple from pony.utils import throw, is_ident -from pony.converting import timedelta2str NoneType = type(None) @@ -218,21 +217,17 @@ def RANDOM(builder): def MOD(builder, a, b): return 'MOD(', builder(a), ', ', builder(b), ')' def DATE_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): - if 
isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' HOUR TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATETIME_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) def build_json_path(builder, path): path_sql, has_params, has_wildcards = SQLBuilder.build_json_path(builder, path) if has_params: throw(TranslationError, "Oracle doesn't allow parameters in JSON paths") diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 96dbfece1..9477d8fe1 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -47,9 +47,11 @@ class PGValue(Value): __slots__ = [] def __unicode__(self): value = self.value - if isinstance(value, bool): return value and 'true' or 'false' + if isinstance(value, bool): + return value and 'true' or 'false' return Value.__unicode__(self) - if not PY2: __str__ = __unicode__ + if not PY2: + __str__ = __unicode__ class PGSQLBuilder(SQLBuilder): dialect = 'PostgreSQL' @@ -70,21 +72,17 @@ def DATE(builder, expr): def RANDOM(builder): return 'random()' def DATE_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " + INTERVAL '", timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return '(', builder(expr), " - INTERVAL '", 
timedelta2str(delta), "' DAY TO SECOND)" return '(', builder(expr), ' - ', builder(delta), ')' + def DATETIME_DIFF(builder, expr1, expr2): + return builder(expr1), ' - ', builder(expr2) def eval_json_path(builder, values): result = [] for value in values: diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index d9009e976..9f23c60f1 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -16,7 +16,7 @@ from pony.orm.core import log_orm from pony.orm.ormtypes import Json, TrackedArray from pony.orm.sqltranslation import SQLTranslator, StringExprMonad -from pony.orm.sqlbuilding import SQLBuilder, join, make_unary_func +from pony.orm.sqlbuilding import SQLBuilder, Value, join, make_unary_func from pony.orm.dbapiprovider import DBAPIProvider, Pool, wrap_dbapi_exceptions from pony.utils import datetime2timestamp, timestamp2datetime, absolutize_path, localbase, throw, reraise, \ cut_traceback_depth @@ -54,10 +54,24 @@ class SQLiteTranslator(SQLTranslator): StringMixin_UPPER = make_overriden_string_func('PY_UPPER') StringMixin_LOWER = make_overriden_string_func('PY_LOWER') +class SQLiteValue(Value): + __slots__ = [] + def __unicode__(self): + value = self.value + if isinstance(value, datetime): + return self.quote_str(datetime2timestamp(value)) + if isinstance(value, date): + return self.quote_str(str(value)) + if isinstance(value, timedelta): + return repr(value.total_seconds() / (24 * 60 * 60)) + return Value.__unicode__(self) + if not PY2: __str__ = __unicode__ + class SQLiteBuilder(SQLBuilder): dialect = 'SQLite' least_func_name = 'min' greatest_func_name = 'max' + value_class = SQLiteValue def __init__(builder, provider, ast): builder.json1_available = provider.json1_available SQLBuilder.__init__(builder, provider, ast) @@ -106,21 +120,25 @@ def datetime_add(builder, funcname, expr, td): if not modifiers: return builder(expr) return funcname, '(', builder(expr), modifiers, ')' def DATE_ADD(builder, expr, delta): - 
if isinstance(delta, timedelta): - return builder.datetime_add('date', expr, delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('date', expr, delta[1]) return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')' def DATE_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('date', expr, -delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('date', expr, -delta[1]) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')' def DATETIME_ADD(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('datetime', expr, delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('datetime', expr, delta[1]) return 'datetime(julianday(', builder(expr), ') + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): - if isinstance(delta, timedelta): - return builder.datetime_add('datetime', expr, -delta) + if delta[0] == 'VALUE' and isinstance(delta[1], timedelta): + return builder.datetime_add('datetime', expr, -delta[1]) return 'datetime(julianday(', builder(expr), ') - ', builder(delta), ')' + def DATETIME_DIFF(builder, expr1, expr2): + return 'julianday(', builder(expr1), ') - julianday(', builder(expr2), ')' def RANDOM(builder): return 'rand()' # return '(random() / 9223372036854775807.0 + 1.0) / 2.0' PY_UPPER = make_unary_func('py_upper') diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index ce6611147..f747b94d4 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -3,11 +3,12 @@ from operator import attrgetter from decimal import Decimal -from datetime import date, datetime +from datetime import date, datetime, timedelta from binascii import hexlify from pony import options from pony.utils 
import datetime2timestamp, throw, is_ident +from pony.converting import timedelta2str from pony.orm.ormtypes import RawSQL, Json class AstError(Exception): pass @@ -68,19 +69,31 @@ def __init__(self, paramstyle, value): self.value = value def __unicode__(self): value = self.value - if value is None: return 'null' - if isinstance(value, bool): return value and '1' or '0' - if isinstance(value, basestring): return self.quote_str(value) - if isinstance(value, datetime): return self.quote_str(datetime2timestamp(value)) - if isinstance(value, date): return self.quote_str(str(value)) + if value is None: + return 'null' + if isinstance(value, bool): + return value and '1' or '0' + if isinstance(value, basestring): + return self.quote_str(value) + if isinstance(value, datetime): + return 'TIMESTAMP ' + self.quote_str(datetime2timestamp(value)) + if isinstance(value, date): + return 'DATE ' + self.quote_str(str(value)) + if isinstance(value, timedelta): + return "INTERVAL '%s' HOUR TO SECOND" % timedelta2str(value) if PY2: - if isinstance(value, (int, long, float, Decimal)): return str(value) - if isinstance(value, buffer): return "X'%s'" % hexlify(value) + if isinstance(value, (int, long, float, Decimal)): + return str(value) + if isinstance(value, buffer): + return "X'%s'" % hexlify(value) else: - if isinstance(value, (int, float, Decimal)): return str(value) - if isinstance(value, bytes): return "X'%s'" % hexlify(value).decode('ascii') - assert False, value # pragma: no cover - if not PY2: __str__ = __unicode__ + if isinstance(value, (int, float, Decimal)): + return str(value) + if isinstance(value, bytes): + return "X'%s'" % hexlify(value).decode('ascii') + assert False, repr(value) # pragma: no cover + if not PY2: + __str__ = __unicode__ def __repr__(self): return '%s(%r)' % (self.__class__.__name__, self.value) def quote_str(self, s): diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index faabda085..d24039cf0 100644 --- a/pony/orm/sqltranslation.py 
+++ b/pony/orm/sqltranslation.py @@ -1746,8 +1746,7 @@ def datetime_binop(monad, monad2): if monad2.type != timedelta: throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(monad2.type), op)) expr_monad_cls = DateExprMonad if monad.type is date else DatetimeExprMonad - delta = monad2.value if isinstance(monad2, TimedeltaConstMonad) else monad2.getsql()[0] - return expr_monad_cls(monad.type, [ sqlop, monad.getsql()[0], delta ], + return expr_monad_cls(monad.type, [ sqlop, monad.getsql()[0], monad2.getsql()[0] ], nullable=monad.nullable or monad2.nullable) datetime_binop.__name__ = sqlop return datetime_binop @@ -1755,11 +1754,26 @@ def datetime_binop(monad, monad2): class DateMixin(MonadMixin): def mixin_init(monad): assert monad.type is date + attr_year = numeric_attr_factory('YEAR') attr_month = numeric_attr_factory('MONTH') attr_day = numeric_attr_factory('DAY') - __add__ = make_datetime_binop('+', 'DATE_ADD') - __sub__ = make_datetime_binop('-', 'DATE_SUB') + + def __add__(monad, other): + if other.type != timedelta: + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '+')) + return DateExprMonad(monad.type, [ 'DATE_ADD', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + + def __sub__(monad, other): + if other.type == timedelta: + return DateExprMonad(monad.type, [ 'DATE_SUB', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + elif other.type == date: + return TimedeltaExprMonad(timedelta, [ 'DATE_DIFF', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '-')) + class TimeMixin(MonadMixin): def mixin_init(monad): @@ -1775,14 +1789,29 @@ def mixin_init(monad): class DatetimeMixin(DateMixin): def mixin_init(monad): assert monad.type is datetime + def call_date(monad): sql = [ 'DATE', monad.getsql()[0] ] return ExprMonad.new(date, sql, 
nullable=monad.nullable) + attr_hour = numeric_attr_factory('HOUR') attr_minute = numeric_attr_factory('MINUTE') attr_second = numeric_attr_factory('SECOND') - __add__ = make_datetime_binop('+', 'DATETIME_ADD') - __sub__ = make_datetime_binop('-', 'DATETIME_SUB') + + def __add__(monad, other): + if other.type != timedelta: + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '+')) + return DatetimeExprMonad(monad.type, [ 'DATETIME_ADD', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + + def __sub__(monad, other): + if other.type == timedelta: + return DatetimeExprMonad(monad.type, [ 'DATETIME_SUB', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + elif other.type == datetime: + return TimedeltaExprMonad(timedelta, [ 'DATETIME_DIFF', monad.getsql()[0], other.getsql()[0] ], + nullable=monad.nullable or other.nullable) + throw(TypeError, _binop_errmsg % (type2str(monad.type), type2str(other.type), '-')) def make_string_binop(op, sqlop): def string_binop(monad, monad2): diff --git a/pony/orm/tests/queries.txt b/pony/orm/tests/queries.txt index 877a05ca9..1b754d782 100644 --- a/pony/orm/tests/queries.txt +++ b/pony/orm/tests/queries.txt @@ -1071,3 +1071,115 @@ FROM "STUDENT" "s1", "STUDENT" "s2" WHERE "s1"."NAME" = "s2"."NAME" AND "s1"."GPA" = "s2"."GPA" AND "s1"."TEL" = "s2"."TEL" + + +# Test date operations: + + +>>> select(s for s in Student if s.dob + timedelta(days=100) < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE date("s"."dob", '+100 days') < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" + INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" + INTERVAL 
'2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE ADDDATE(`s`.`dob`, INTERVAL '2400:0:0' HOUR_SECOND) < DATE '2010-01-01' + + +>>> td = timedelta(days=100) +>>> select(s for s in Student if s.dob + td < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE datetime(julianday("s"."dob") + ?) < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" + %(p1)s) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" + :p1) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE ADDDATE(`s`.`dob`, %s) < DATE '2010-01-01' + + +>>> select(s for s in Student if s.dob - timedelta(days=100) < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE date("s"."dob", '-100 days') < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" - INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" - INTERVAL '2400:0:0' HOUR TO SECOND) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE SUBDATE(`s`.`dob`, INTERVAL '2400:0:0' HOUR_SECOND) < DATE '2010-01-01' + +>>> td = timedelta(days=100) +>>> select(s for s in Student if s.dob - td < date(2010, 1, 1)) + +SQLite: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "Student" "s" +WHERE 
datetime(julianday("s"."dob") - ?) < '2010-01-01' + +PostgreSQL: + +SELECT "s"."id", "s"."name", "s"."dob", "s"."tel", "s"."gpa", "s"."group" +FROM "student" "s" +WHERE ("s"."dob" - %(p1)s) < DATE '2010-01-01' + +Oracle: + +SELECT "s"."ID", "s"."NAME", "s"."DOB", "s"."TEL", "s"."GPA", "s"."GROUP" +FROM "STUDENT" "s" +WHERE ("s"."DOB" - :p1) < DATE '2010-01-01' + +MySQL: + +SELECT `s`.`id`, `s`.`name`, `s`.`dob`, `s`.`tel`, `s`.`gpa`, `s`.`group` +FROM `student` `s` +WHERE SUBDATE(`s`.`dob`, %s) < DATE '2010-01-01' diff --git a/pony/orm/tests/sql_tests.py b/pony/orm/tests/sql_tests.py index 668293ece..ec68a82d5 100644 --- a/pony/orm/tests/sql_tests.py +++ b/pony/orm/tests/sql_tests.py @@ -2,6 +2,7 @@ from pony.py23compat import PY2 import re, os, os.path, sys +from datetime import datetime, timedelta from pony import orm from pony.orm import core @@ -55,6 +56,7 @@ def do_test(provider_name, raw_server_version): return module = sys.modules[module_name] globals = vars(module).copy() + globals.update(datetime=datetime, timedelta=timedelta) with orm.db_session: for statement in statements[:-1]: code = compile(statement, '', 'exec') diff --git a/pony/orm/tests/test_datetime.py b/pony/orm/tests/test_datetime.py new file mode 100644 index 000000000..42448efb5 --- /dev/null +++ b/pony/orm/tests/test_datetime.py @@ -0,0 +1,134 @@ +from __future__ import absolute_import, print_function, division +from pony.py23compat import PY2 + +import unittest +from datetime import date, datetime, timedelta + +from pony.orm.core import * +from pony.orm.tests.testutils import * + +db = Database('sqlite', ':memory:') + +class Entity1(db.Entity): + id = PrimaryKey(int) + d = Required(date) + dt = Required(datetime) + +db.generate_mapping(create_tables=True) + +with db_session: + Entity1(id=1, d=date(2009, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) + Entity1(id=2, d=date(2010, 10, 21), dt=datetime(2010, 10, 21, 10, 21, 31)) + Entity1(id=3, d=date(2011, 11, 22), dt=datetime(2011, 11, 22, 
10, 20, 32)) + +class TestDate(unittest.TestCase): + def setUp(self): + rollback() + db_session.__enter__() + + def tearDown(self): + rollback() + db_session.__exit__() + + def test_create(self): + e1 = Entity1(id=4, d=date(2011, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) + + def test_date_year(self): + result = select(e for e in Entity1 if e.d.year > 2009) + self.assertEqual(len(result), 2) + + def test_date_month(self): + result = select(e for e in Entity1 if e.d.month == 10) + self.assertEqual(len(result), 2) + + def test_date_day(self): + result = select(e for e in Entity1 if e.d.day == 22) + self.assertEqual(len(result), 1) + + def test_datetime_year(self): + result = select(e for e in Entity1 if e.dt.year > 2009) + self.assertEqual(len(result), 2) + + def test_datetime_month(self): + result = select(e for e in Entity1 if e.dt.month == 10) + self.assertEqual(len(result), 2) + + def test_datetime_day(self): + result = select(e for e in Entity1 if e.dt.day == 22) + self.assertEqual(len(result), 1) + + def test_datetime_hour(self): + result = select(e for e in Entity1 if e.dt.hour == 10) + self.assertEqual(len(result), 3) + + def test_datetime_minute(self): + result = select(e for e in Entity1 if e.dt.minute == 20) + self.assertEqual(len(result), 2) + + def test_datetime_second(self): + result = select(e for e in Entity1 if e.dt.second == 30) + self.assertEqual(len(result), 1) + + def test_date_sub_date(self): + dt = date(2012, 1, 1) + result = select(e.id for e in Entity1 if dt - e.d > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + def test_datetime_sub_datetime(self): + dt = datetime(2012, 1, 1, 10, 20, 30) + result = select(e.id for e in Entity1 if dt - e.dt > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + def test_date_sub_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.d - td < date(2009, 1, 1)) + self.assertEqual(set(result), {1}) + + def test_date_sub_const_timedelta(self): + 
result = select(e.id for e in Entity1 if e.d - timedelta(days=500) < date(2009, 1, 1)) + self.assertEqual(set(result), {1}) + + def test_datetime_sub_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.dt - td < datetime(2009, 1, 1, 10, 20, 30)) + self.assertEqual(set(result), {1}) + + def test_datetime_sub_const_timedelta(self): + result = select(e.id for e in Entity1 if e.dt - timedelta(days=500) < datetime(2009, 1, 1, 10, 20, 30)) + self.assertEqual(set(result), {1}) + + def test_date_add_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.d + td > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + def test_date_add_const_timedelta(self): + result = select(e.id for e in Entity1 if e.d + timedelta(days=500) > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + def test_datetime_add_timedelta_param(self): + td = timedelta(days=500) + result = select(e.id for e in Entity1 if e.dt + td > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + def test_datetime_add_const_timedelta(self): + result = select(e.id for e in Entity1 if e.dt + timedelta(days=500) > date(2013, 1, 1)) + self.assertEqual(set(result), {3}) + + @raises_exception(TypeError, "Unsupported operand types 'date' and '%s' " + "for operation '-' in expression: e.d - s" % ('unicode' if PY2 else 'str')) + def test_date_sub_error(self): + s = 'hello' + result = select(e.id for e in Entity1 if e.d - s > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + @raises_exception(TypeError, "Unsupported operand types 'datetime' and '%s' " + "for operation '-' in expression: e.dt - s" % ('unicode' if PY2 else 'str')) + def test_datetime_sub_error(self): + s = 'hello' + result = select(e.id for e in Entity1 if e.dt - s > timedelta(days=500)) + self.assertEqual(set(result), {1}) + + +if __name__ == '__main__': + unittest.main() diff --git a/pony/orm/tests/test_declarative_date.py 
b/pony/orm/tests/test_declarative_date.py deleted file mode 100644 index b7a89541f..000000000 --- a/pony/orm/tests/test_declarative_date.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import absolute_import, print_function, division - -import unittest -from datetime import date, datetime - -from pony.orm.core import * -from pony.orm.tests.testutils import * - -db = Database('sqlite', ':memory:') - -class Entity1(db.Entity): - a = PrimaryKey(int) - b = Required(date) - c = Required(datetime) - -db.generate_mapping(create_tables=True) - -with db_session: - Entity1(a=1, b=date(2009, 10, 20), c=datetime(2009, 10, 20, 10, 20, 30)) - Entity1(a=2, b=date(2010, 10, 21), c=datetime(2010, 10, 21, 10, 21, 31)) - Entity1(a=3, b=date(2011, 11, 22), c=datetime(2011, 11, 22, 10, 20, 32)) - -class TestDate(unittest.TestCase): - def setUp(self): - rollback() - db_session.__enter__() - def tearDown(self): - rollback() - db_session.__exit__() - def test_create(self): - e1 = Entity1(a=4, b=date(2011, 10, 20), c=datetime(2009, 10, 20, 10, 20, 30)) - def test_date_year(self): - result = select(e for e in Entity1 if e.b.year > 2009) - self.assertEqual(len(result), 2) - def test_date_month(self): - result = select(e for e in Entity1 if e.b.month == 10) - self.assertEqual(len(result), 2) - def test_date_day(self): - result = select(e for e in Entity1 if e.b.day == 22) - self.assertEqual(len(result), 1) - def test_datetime_year(self): - result = select(e for e in Entity1 if e.c.year > 2009) - self.assertEqual(len(result), 2) - def test_datetime_month(self): - result = select(e for e in Entity1 if e.c.month == 10) - self.assertEqual(len(result), 2) - def test_datetime_day(self): - result = select(e for e in Entity1 if e.c.day == 22) - self.assertEqual(len(result), 1) - def test_datetime_hour(self): - result = select(e for e in Entity1 if e.c.hour == 10) - self.assertEqual(len(result), 3) - def test_datetime_minute(self): - result = select(e for e in Entity1 if e.c.minute == 20) - 
self.assertEqual(len(result), 2) - def test_datetime_second(self): - result = select(e for e in Entity1 if e.c.second == 30) - self.assertEqual(len(result), 1) - -if __name__ == '__main__': - unittest.main() From 9d008b9238f0af495c0d7a7cb785db20723ccf89 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Oct 2019 21:50:12 +0300 Subject: [PATCH 505/547] SQLite does not allow to specify distinct and separator in group_concat at the same time --- pony/orm/sqltranslation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index d24039cf0..e3be1f621 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2818,6 +2818,8 @@ class FuncGroupConcatMonad(FuncMonad): func = utils.group_concat, core.group_concat def call(monad, x, sep=None, distinct=None): if sep is not None: + if distinct and monad.translator.database.provider.dialect == 'SQLite': + throw(TypeError, 'SQLite does not allow to specify distinct and separator in group_concat at the same time: {EXPR}') if not(isinstance(sep, StringConstMonad) and isinstance(sep.value, basestring)): throw(TypeError, '`sep` option of `group_concat` should be type of str. 
Got: %s' % ast2src(sep.node)) sep = sep.value From 66c29f56e89d5abad15c24d9b344fdb54d9a445b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 23 Oct 2019 23:55:44 +0300 Subject: [PATCH 506/547] Update changelog and Pony version: 0.7.11-dev -> 0.7.11 --- CHANGELOG.md | 23 +++++++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 65d16d9c8..74912fc20 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,26 @@ +# PonyORM release 0.7.11 (2019-10-23) + +## Features + +* #472: Python 3.8 support +* Support of hybrid functions (inlining simple Python functions into query) +* #438: support datetime-datetime, datetime-timedelta, datetime+timedelta in queries + +## Bugfixes + +* #430: add ON DELETE CASCADE for many-to-many relationships +* #465: Should reconnect to MySQL on OperationalError 2013 'Lost connection to MySQL server during query' +* #468: Tuple-value comparisons generate incorrect queries +* #470 fix PendingDeprecationWarning of imp module +* Fix incorrect unpickling of objects with Json attributes +* Check value of discriminator column on object creation if set explicitly +* Correctly handle Flask current_user proxy when adding new items to collections +* Some bugs in syntax of aggregated queries were fixed +* Fix syntax of bulk delete queries +* Bulk delete queries should clear query results cache so next select will get correct result from the database +* Fix error message when hybrid method is too complex to decompile + + # PonyORM release 0.7.10 (2019-04-20) ## Bugfixes diff --git a/pony/__init__.py b/pony/__init__.py index f0f04fe91..48c218bff 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.11-dev' +__version__ = '0.7.11' uid = str(random.randint(1, 1000000)) From 21459cfdf1ca6e5167f0b9b0e3acb40aebaaf180 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: 
Mon, 28 Oct 2019 18:18:13 +0300 Subject: [PATCH 507/547] Update Pony version: 0.7.11 -> 0.7.12-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 48c218bff..59e66688e 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.11' +__version__ = '0.7.12-dev' uid = str(random.randint(1, 1000000)) From db29fb5308343fdd3840a639740f95af9812a65e Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Mon, 28 Oct 2019 18:19:02 +0300 Subject: [PATCH 508/547] Fix column definition when sql_default is specified --- pony/orm/dbschema.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index a2bc4e372..9b32c081b 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -219,19 +219,24 @@ def get_sql(column): result = [] append = result.append append(quote_name(column.name)) + + def add_default(): + if column.sql_default not in (None, True, False): + append(case('DEFAULT')) + append(column.sql_default) + if column.is_pk == 'auto' and column.auto_template and column.converter.py_type in int_types: append(case(column.auto_template % dict(type=column.sql_type))) + add_default() else: append(case(column.sql_type)) + add_default() if column.is_pk: if schema.dialect == 'SQLite': append(case('NOT NULL')) append(case('PRIMARY KEY')) else: if column.is_unique: append(case('UNIQUE')) if column.is_not_null: append(case('NOT NULL')) - if column.sql_default not in (None, True, False): - append(case('DEFAULT')) - append(column.sql_default) if schema.inline_fk_syntax and not schema.named_foreign_keys: foreign_key = table.foreign_keys.get((column,)) if foreign_key is not None: From 29285347c20b4b9d647e7c2989a8492bf7f6d42b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 30 Oct 2019 18:56:38 +0300 Subject: [PATCH 509/547] Relax checks in 
cache.update_simple_index, cache.db_update_simple_index, cache.update_composite_index, cache.db_update_composite_index --- pony/orm/core.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 6b6628c93..b619fc3b1 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1936,7 +1936,7 @@ def _calc_modified_m2m(cache): cache.modified_collections.clear() return modified_m2m def update_simple_index(cache, obj, attr, old_val, new_val, undo): - assert old_val != new_val + if old_val == new_val: return cache_index = cache.indexes[attr] if new_val is not None: obj2 = cache_index.setdefault(new_val, obj) @@ -1945,7 +1945,7 @@ def update_simple_index(cache, obj, attr, old_val, new_val, undo): if old_val is not None: del cache_index[old_val] undo.append((cache_index, old_val, new_val)) def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval): - assert old_dbval != new_dbval + if old_dbval == new_dbval: return cache_index = cache.indexes[attr] if new_dbval is not None: obj2 = cache_index.setdefault(new_dbval, obj) @@ -1955,10 +1955,10 @@ def db_update_simple_index(cache, obj, attr, old_dbval, new_dbval): # attribute which was created or updated lately clashes with one stored in database cache_index.pop(old_dbval, None) def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): - assert prev_vals != new_vals if None in prev_vals: prev_vals = None if None in new_vals: new_vals = None if prev_vals is None and new_vals is None: return + if prev_vals == new_vals: return cache_index = cache.indexes[attrs] if new_vals is not None: obj2 = cache_index.setdefault(new_vals, obj) @@ -1969,7 +1969,7 @@ def update_composite_index(cache, obj, attrs, prev_vals, new_vals, undo): if prev_vals is not None: del cache_index[prev_vals] undo.append((cache_index, prev_vals, new_vals)) def db_update_composite_index(cache, obj, attrs, prev_vals, new_vals): - assert prev_vals != new_vals + if prev_vals == new_vals: 
return cache_index = cache.indexes[attrs] if None not in new_vals: obj2 = cache_index.setdefault(new_vals, obj) From e90d200ac5dca6f34f652055450a41fb40f15849 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Wed, 29 Jan 2020 19:39:58 +0300 Subject: [PATCH 510/547] Update BACKERS.md --- BACKERS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/BACKERS.md b/BACKERS.md index 033085f51..cb5e8cb4f 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -14,3 +14,4 @@ Pony ORM is Apache 2.0 licensed open source project. If you would like to suppor - Christian Macht - Johnathan Nader - Andrei Rachalouski +- Juan Pablo Scaletti From a446c7f0e3229bd70c48e0d801b958037afc780b Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Wed, 29 Jan 2020 19:39:58 +0300 Subject: [PATCH 511/547] Update BACKERS.md --- BACKERS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/BACKERS.md b/BACKERS.md index 033085f51..cb5e8cb4f 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -14,3 +14,4 @@ Pony ORM is Apache 2.0 licensed open source project. 
If you would like to suppor - Christian Macht - Johnathan Nader - Andrei Rachalouski +- Juan Pablo Scaletti From 493a1197f2a4e6da37d97b3020859c005cda558a Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 3 Jan 2020 10:22:09 +0300 Subject: [PATCH 512/547] Fix deduplication --- pony/orm/core.py | 2 +- pony/utils/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b619fc3b1..162388780 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1745,7 +1745,7 @@ def __init__(cache, database): cache.objects_to_save = [] cache.saved_objects = [] cache.query_results = {} - cache.dbvals_deduplication_cache = {} + cache.dbvals_deduplication_cache = defaultdict(dict) cache.modified = False cache.db_session = db_session = local.db_session cache.immediate = db_session is not None and db_session.immediate diff --git a/pony/utils/utils.py b/pony/utils/utils.py index 9aecef744..9c322e2a4 100644 --- a/pony/utils/utils.py +++ b/pony/utils/utils.py @@ -605,6 +605,6 @@ def deref_proxy(value): def deduplicate(value, deduplication_cache): t = type(value) try: - return deduplication_cache.setdefault(t, t).setdefault(value, value) + return deduplication_cache[t].setdefault(value, value) except: return value From 51876a1172754867a5e4b6682da1b4b12ebb273f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 3 Jan 2020 11:13:29 +0300 Subject: [PATCH 513/547] Fix determination of interactive mode in PyCharm --- pony/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 59e66688e..3cd2176ef 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -12,7 +12,7 @@ def detect_mode(): try: import google.appengine except ImportError: pass else: - if os.environ.get('SERVER_SOFTWARE', '').startswith('Development'): + if os.getenv('SERVER_SOFTWARE', '').startswith('Development'): return 'GAE-LOCAL' return 'GAE-SERVER' @@ -25,6 +25,9 @@ def 
detect_mode(): if not hasattr(main, '__file__'): # console return 'INTERACTIVE' + if os.getenv('IPYTHONENABLE', '') == 'True': + return 'INTERACTIVE' + if getattr(main, 'INTERACTIVE_MODE_AVAILABLE', False): # pycharm console return 'INTERACTIVE' From ceecad7570ca1ea810c90e58e53dd88de1c6c311 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 5 Nov 2019 10:15:55 +0300 Subject: [PATCH 514/547] CockroachDB support added --- pony/orm/dbproviders/cockroach.py | 112 ++++++++++++++++++++++++++++++ pony/orm/examples/university1.py | 1 + pony/orm/sqltranslation.py | 8 ++- 3 files changed, 120 insertions(+), 1 deletion(-) create mode 100644 pony/orm/dbproviders/cockroach.py diff --git a/pony/orm/dbproviders/cockroach.py b/pony/orm/dbproviders/cockroach.py new file mode 100644 index 000000000..90dc7a8e2 --- /dev/null +++ b/pony/orm/dbproviders/cockroach.py @@ -0,0 +1,112 @@ +from __future__ import absolute_import +from pony.py23compat import PY2, basestring, unicode, buffer, int_types + +from decimal import Decimal +from datetime import datetime, date, time, timedelta +from uuid import UUID + +try: + import psycopg2 +except ImportError: + try: + from psycopg2cffi import compat + except ImportError: + raise ImportError('In order to use PonyORM with CockroachDB please install psycopg2 or psycopg2cffi') + else: + compat.register() + +from pony.orm.dbproviders.postgres import ( + PGSQLBuilder, PGColumn, PGSchema, PGTranslator, PGProvider, + PGStrConverter, PGIntConverter, PGRealConverter, + PGDatetimeConverter, PGTimedeltaConverter, + PGBlobConverter, PGJsonConverter, PGArrayConverter, +) + +from pony.orm import core, dbapiprovider, ormtypes +from pony.orm.core import log_orm +from pony.orm.dbapiprovider import wrap_dbapi_exceptions + +NoneType = type(None) + +class CRColumn(PGColumn): + auto_template = 'SERIAL PRIMARY KEY' + +class CRSchema(PGSchema): + column_class = CRColumn + +class CRTranslator(PGTranslator): + pass + +class CRSQLBuilder(PGSQLBuilder): + pass + 
+class CRIntConverter(PGIntConverter): + signed_types = {None: 'INT', 8: 'INT2', 16: 'INT2', 24: 'INT8', 32: 'INT8', 64: 'INT8'} + unsigned_types = {None: 'INT', 8: 'INT2', 16: 'INT4', 24: 'INT8', 32: 'INT8'} + # signed_types = {None: 'INT', 8: 'INT2', 16: 'INT2', 24: 'INT4', 32: 'INT4', 64: 'INT8'} + # unsigned_types = {None: 'INT', 8: 'INT2', 16: 'INT4', 24: 'INT4', 32: 'INT8'} + +class CRBlobConverter(PGBlobConverter): + def sql_type(converter): + return 'BYTES' + +class CRTimedeltaConverter(PGTimedeltaConverter): + sql_type_name = 'INTERVAL' + +class PGUuidConverter(dbapiprovider.UuidConverter): + def py2sql(converter, val): + return val + +class CRArrayConverter(PGArrayConverter): + array_types = { + int: ('INT', PGIntConverter), + unicode: ('STRING', PGStrConverter), + float: ('DOUBLE PRECISION', PGRealConverter) + } + +class CRProvider(PGProvider): + dbapi_module = psycopg2 + dbschema_cls = CRSchema + translator_cls = CRTranslator + sqlbuilder_cls = CRSQLBuilder + array_converter_cls = CRArrayConverter + + default_schema_name = 'public' + + fk_types = { 'SERIAL' : 'INT8' } + + def normalize_name(provider, name): + return name[:provider.max_name_len].lower() + + @wrap_dbapi_exceptions + def set_transaction_mode(provider, connection, cache): + assert not cache.in_transaction + if cache.immediate and connection.autocommit: + connection.autocommit = False + if core.local.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') + db_session = cache.db_session + if db_session is not None and db_session.serializable: + pass + elif not cache.immediate and not connection.autocommit: + connection.autocommit = True + if core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') + if db_session is not None and (db_session.serializable or db_session.ddl): + cache.in_transaction = True + + converter_classes = [ + (NoneType, dbapiprovider.NoneConverter), + (bool, dbapiprovider.BoolConverter), + (basestring, PGStrConverter), + (int_types, CRIntConverter), + (float, 
PGRealConverter), + (Decimal, dbapiprovider.DecimalConverter), + (datetime, PGDatetimeConverter), + (date, dbapiprovider.DateConverter), + (time, dbapiprovider.TimeConverter), + (timedelta, CRTimedeltaConverter), + (UUID, PGUuidConverter), + (buffer, CRBlobConverter), + (ormtypes.Json, PGJsonConverter), + ] + +provider_cls = CRProvider diff --git a/pony/orm/examples/university1.py b/pony/orm/examples/university1.py index 81a1b91f9..b575743d2 100644 --- a/pony/orm/examples/university1.py +++ b/pony/orm/examples/university1.py @@ -46,6 +46,7 @@ class Student(db.Entity): sqlite=dict(provider='sqlite', filename='university1.sqlite', create_db=True), mysql=dict(provider='mysql', host="localhost", user="pony", passwd="pony", db="pony"), postgres=dict(provider='postgres', user='pony', password='pony', host='localhost', database='pony'), + cockroach=dict(provider='cockroach', user='root', host='localhost', port=26257, database='pony', sslmode='disable'), oracle=dict(provider='oracle', user='c##pony', password='pony', dsn='localhost/orcl') ) db.bind(**params['sqlite']) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e3be1f621..b2fdf6308 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2753,10 +2753,13 @@ class FuncConcatMonad(FuncMonad): def call(monad, *args): if len(args) < 2: throw(TranslationError, 'concat() function requires at least two arguments') result_ast = [ 'CONCAT' ] + translator = monad.translator for arg in args: t = arg.type if isinstance(t, EntityMeta) or type(t) in (tuple, SetType): throw(TranslationError, 'Invalid argument of concat() function: %s' % ast2src(arg.node)) + if translator.database.provider_name == 'cockroach' and not isinstance(arg, StringMixin): + arg = arg.to_str() result_ast.extend(arg.getsql()) return ExprMonad.new(unicode, result_ast, nullable=any(arg.nullable for arg in args)) @@ -3048,7 +3051,10 @@ def count(monad, distinct=None): else: make_aggr = lambda expr_list: [ 'COUNT', 
None, [ 'COUNT', None ] ] elif translator.dialect == 'PostgreSQL': row = [ 'ROW' ] + expr_list - expr = [ 'CASE', None, [ [ [ 'IS_NULL', row ], [ 'VALUE', None ] ] ], row ] + cond = [ 'IS_NULL', row ] + if translator.database.provider_name == 'cockroach': + cond = [ 'OR' ] + [ [ 'IS_NULL', expr ] for expr in expr_list ] + expr = [ 'CASE', None, [ [ cond, [ 'VALUE', None ] ] ], row ] make_aggr = lambda expr_list: [ 'COUNT', True, expr ] elif translator.row_value_syntax: make_aggr = lambda expr_list: [ 'COUNT', True ] + expr_list From 994654e3e58dbc34c08492f9bb3d50144e25481c Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 31 Jan 2020 00:13:55 +0100 Subject: [PATCH 515/547] Fix cockroach ddl issue --- pony/orm/dbproviders/cockroach.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pony/orm/dbproviders/cockroach.py b/pony/orm/dbproviders/cockroach.py index 90dc7a8e2..95c598ce4 100644 --- a/pony/orm/dbproviders/cockroach.py +++ b/pony/orm/dbproviders/cockroach.py @@ -81,12 +81,12 @@ def normalize_name(provider, name): @wrap_dbapi_exceptions def set_transaction_mode(provider, connection, cache): assert not cache.in_transaction + db_session = cache.db_session + if db_session is not None and db_session.ddl: + cache.immediate = False if cache.immediate and connection.autocommit: connection.autocommit = False if core.local.debug: log_orm('SWITCH FROM AUTOCOMMIT TO TRANSACTION MODE') - db_session = cache.db_session - if db_session is not None and db_session.serializable: - pass elif not cache.immediate and not connection.autocommit: connection.autocommit = True if core.local.debug: log_orm('SWITCH TO AUTOCOMMIT MODE') From 6822b1ce0aeaf8cbd178d9554f1571fa252ab23d Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 31 Jan 2020 00:19:48 +0300 Subject: [PATCH 516/547] cockroach retry support --- pony/orm/core.py | 14 +++++---- pony/orm/dbapiprovider.py | 63 ++++++++++++++++++++++----------------- 2 files changed, 45 insertions(+), 
32 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index 162388780..7acc176a2 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -530,14 +530,18 @@ def new_func(func, *args, **kwargs): return result except: exc_type, exc, tb = sys.exc_info() - retry_exceptions = db_session.retry_exceptions - if not callable(retry_exceptions): - do_retry = issubclass(exc_type, tuple(retry_exceptions)) + if getattr(exc, 'should_retry', False): + do_retry = True else: - assert exc is not None # exc can be None in Python 2.6 - do_retry = retry_exceptions(exc) + retry_exceptions = db_session.retry_exceptions + if not callable(retry_exceptions): + do_retry = issubclass(exc_type, tuple(retry_exceptions)) + else: + assert exc is not None # exc can be None in Python 2.6 + do_retry = retry_exceptions(exc) if not do_retry: raise + rollback() finally: db_session.__exit__(exc_type, exc, tb) reraise(exc_type, exc, tb) diff --git a/pony/orm/dbapiprovider.py b/pony/orm/dbapiprovider.py index 9d31dd8ce..a6612d24c 100644 --- a/pony/orm/dbapiprovider.py +++ b/pony/orm/dbapiprovider.py @@ -45,34 +45,43 @@ class NotSupportedError(DatabaseError): pass @decorator def wrap_dbapi_exceptions(func, provider, *args, **kwargs): dbapi_module = provider.dbapi_module + should_retry = False try: - if provider.dialect != 'SQLite': - return func(provider, *args, **kwargs) - else: - provider.local_exceptions.keep_traceback = True - try: return func(provider, *args, **kwargs) - finally: provider.local_exceptions.keep_traceback = False - except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) - except dbapi_module.ProgrammingError as e: - if provider.dialect == 'PostgreSQL': - msg = str(e) - if msg.startswith('operator does not exist:') and ' json ' in msg: - msg += ' (Note: use column type `jsonb` instead of `json`)' - raise ProgrammingError(e, msg, *e.args[1:]) - raise ProgrammingError(e) - except dbapi_module.InternalError as e: raise InternalError(e) - except 
dbapi_module.IntegrityError as e: raise IntegrityError(e) - except dbapi_module.OperationalError as e: - if provider.dialect == 'SQLite': provider.restore_exception() - raise OperationalError(e) - except dbapi_module.DataError as e: raise DataError(e) - except dbapi_module.DatabaseError as e: raise DatabaseError(e) - except dbapi_module.InterfaceError as e: - if e.args == (0, '') and getattr(dbapi_module, '__name__', None) == 'MySQLdb': - throw(InterfaceError, e, 'MySQL server misconfiguration') - raise InterfaceError(e) - except dbapi_module.Error as e: raise Error(e) - except dbapi_module.Warning as e: raise Warning(e) + try: + if provider.dialect != 'SQLite': + return func(provider, *args, **kwargs) + else: + provider.local_exceptions.keep_traceback = True + try: return func(provider, *args, **kwargs) + finally: provider.local_exceptions.keep_traceback = False + except dbapi_module.NotSupportedError as e: raise NotSupportedError(e) + except dbapi_module.ProgrammingError as e: + if provider.dialect == 'PostgreSQL': + msg = str(e) + if msg.startswith('operator does not exist:') and ' json ' in msg: + msg += ' (Note: use column type `jsonb` instead of `json`)' + raise ProgrammingError(e, msg, *e.args[1:]) + raise ProgrammingError(e) + except dbapi_module.InternalError as e: raise InternalError(e) + except dbapi_module.IntegrityError as e: raise IntegrityError(e) + except dbapi_module.OperationalError as e: + if provider.dialect == 'PostgreSQL' and e.pgcode == '40001': + should_retry = True + if provider.dialect == 'SQLite': + provider.restore_exception() + raise OperationalError(e) + except dbapi_module.DataError as e: raise DataError(e) + except dbapi_module.DatabaseError as e: raise DatabaseError(e) + except dbapi_module.InterfaceError as e: + if e.args == (0, '') and getattr(dbapi_module, '__name__', None) == 'MySQLdb': + throw(InterfaceError, e, 'MySQL server misconfiguration') + raise InterfaceError(e) + except dbapi_module.Error as e: raise Error(e) + except 
dbapi_module.Warning as e: raise Warning(e) + except Exception as e: + if should_retry: + e.should_retry = True + raise def unexpected_args(attr, args): throw(TypeError, 'Unexpected positional argument{} for attribute {}: {}'.format( From bd6dc31467a4431bd4902645655e92b6b24a96a0 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 4 Dec 2019 17:19:03 +0300 Subject: [PATCH 517/547] New tests mechanics for testing SQLite, PostgreSQL & CockroachDB --- pony/orm/tests/__init__.py | 70 ++++- pony/orm/tests/model1.py | 7 +- pony/orm/tests/py36_test_f_strings.py | 27 +- pony/orm/tests/test_array.py | 21 +- pony/orm/tests/test_attribute_options.py | 51 ++-- pony/orm/tests/test_autostrip.py | 11 +- pony/orm/tests/test_buffer.py | 24 +- pony/orm/tests/test_bug_170.py | 25 +- pony/orm/tests/test_bug_182.py | 40 +-- pony/orm/tests/test_bug_331.py | 40 ++- pony/orm/tests/test_bug_386.py | 21 +- pony/orm/tests/test_cascade.py | 58 ++-- pony/orm/tests/test_cascade_delete.py | 64 +++-- pony/orm/tests/test_collections.py | 10 + pony/orm/tests/test_core_find_in_cache.py | 50 +++- pony/orm/tests/test_core_multiset.py | 51 ++-- pony/orm/tests/test_crud.py | 34 ++- pony/orm/tests/test_crud_raw_sql.py | 13 +- pony/orm/tests/test_datetime.py | 22 +- pony/orm/tests/test_db_session.py | 35 ++- .../tests/test_declarative_attr_set_monad.py | 71 ++--- pony/orm/tests/test_declarative_exceptions.py | 26 +- pony/orm/tests/test_declarative_func_monad.py | 54 ++-- .../test_declarative_join_optimization.py | 24 +- .../test_declarative_object_flat_monad.py | 51 ++-- .../tests/test_declarative_orderby_limit.py | 31 +- .../tests/test_declarative_query_set_monad.py | 32 ++- .../tests/test_declarative_sqltranslator.py | 66 +++-- .../tests/test_declarative_sqltranslator2.py | 96 ++++--- pony/orm/tests/test_declarative_strings.py | 171 ++++++----- pony/orm/tests/test_deduplication.py | 38 +-- pony/orm/tests/test_diagram.py | 77 +++-- pony/orm/tests/test_diagram_attribute.py | 265 
+++++++++--------- pony/orm/tests/test_diagram_keys.py | 49 ++-- pony/orm/tests/test_distinct.py | 42 +-- pony/orm/tests/test_entity_init.py | 33 ++- pony/orm/tests/test_entity_instances.py | 103 ------- pony/orm/tests/test_entity_proxy.py | 63 ++--- pony/orm/tests/test_exists.py | 31 +- pony/orm/tests/test_filter.py | 9 + pony/orm/tests/test_flush.py | 32 ++- pony/orm/tests/test_frames.py | 20 +- pony/orm/tests/test_generator_db_session.py | 7 +- pony/orm/tests/test_get_pk.py | 84 +++--- pony/orm/tests/test_getattr.py | 10 +- pony/orm/tests/test_hooks.py | 32 ++- .../test_hybrid_methods_and_properties.py | 39 ++- pony/orm/tests/test_indexes.py | 42 ++- pony/orm/tests/test_inheritance.py | 68 +++-- pony/orm/tests/test_inner_join_syntax.py | 64 +++-- pony/orm/tests/test_isinstance.py | 36 ++- pony/orm/tests/test_json.py | 23 +- pony/orm/tests/test_lazy.py | 15 +- pony/orm/tests/test_mapping.py | 48 ++-- .../orm/tests/test_objects_to_save_cleanup.py | 33 +-- pony/orm/tests/test_prefetching.py | 54 ++-- pony/orm/tests/test_query.py | 25 +- pony/orm/tests/test_random.py | 26 +- pony/orm/tests/test_raw_sql.py | 23 +- pony/orm/tests/test_relations_m2m.py | 98 +++---- pony/orm/tests/test_relations_one2many.py | 8 +- pony/orm/tests/test_relations_one2one1.py | 22 +- pony/orm/tests/test_relations_one2one2.py | 14 +- pony/orm/tests/test_relations_one2one3.py | 9 +- pony/orm/tests/test_relations_one2one4.py | 34 +-- .../orm/tests/test_relations_symmetric_m2m.py | 14 +- .../tests/test_relations_symmetric_one2one.py | 13 +- .../tests/test_select_from_select_queries.py | 43 +-- pony/orm/tests/test_show.py | 35 ++- .../tests/test_sqlbuilding_formatstyles.py | 1 + pony/orm/tests/test_sqlbuilding_sqlast.py | 14 +- pony/orm/tests/test_sqlite_str_functions.py | 5 + pony/orm/tests/test_time_parsing.py | 1 + pony/orm/tests/test_to_dict.py | 167 +++++++---- pony/orm/tests/test_transaction_lock.py | 26 +- pony/orm/tests/test_validate.py | 40 ++- pony/orm/tests/test_virtuals.py | 0 
pony/orm/tests/test_volatile.py | 17 +- 78 files changed, 1933 insertions(+), 1315 deletions(-) delete mode 100644 pony/orm/tests/test_entity_instances.py create mode 100644 pony/orm/tests/test_virtuals.py diff --git a/pony/orm/tests/__init__.py b/pony/orm/tests/__init__.py index 150399271..2f1caa71d 100644 --- a/pony/orm/tests/__init__.py +++ b/pony/orm/tests/__init__.py @@ -1,4 +1,72 @@ +import unittest +import os +import types import pony.orm.core, pony.options pony.options.CUT_TRACEBACK = False -pony.orm.core.sql_debug(False) \ No newline at end of file +pony.orm.core.sql_debug(False) + + +def _load_env(): + settings_filename = os.environ.get('pony_test_db') + if settings_filename is None: + print('use default sqlite provider') + return dict(provider='sqlite', filename=':memory:') + with open(settings_filename, 'r') as f: + content = f.read() + + config = {} + exec(content, config) + settings = config.get('settings') + if settings is None or not isinstance(settings, dict): + raise ValueError('Incorrect settings pony test db file contents') + provider = settings.get('provider') + if provider is None: + raise ValueError('Incorrect settings pony test db file contents: provider was not specified') + print('use provider %s' % provider) + return settings + + +db_params = _load_env() + + +def setup_database(db): + if db.provider is None: + db.bind(**db_params) + if db.schema is None: + db.generate_mapping(check_tables=False) + db.drop_all_tables(with_all_data=True) + db.create_tables() + + +def teardown_database(db): + if db.schema: + db.drop_all_tables(with_all_data=True) + db.disconnect() + + +def only_for(providers): + if not isinstance(providers, (list, tuple)): + providers = [providers] + def decorator(x): + if isinstance(x, type) and issubclass(x, unittest.TestCase): + @classmethod + def setUpClass(cls): + raise unittest.SkipTest('%s tests implemented only for %s provider%s' % ( + cls.__name__, ', '.join(providers), '' if len(providers) < 2 else 's' + )) + if 
db_params['provider'] not in providers: + x.setUpClass = setUpClass + result = x + elif isinstance(x, types.FunctionType): + def new_test_func(self): + if db_params['provider'] not in providers: + raise unittest.SkipTest('%s test implemented only for %s provider%s' % ( + x.__name__, ', '.join(providers), '' if len(providers) < 2 else 's' + )) + return x(self) + result = new_test_func + else: + raise TypeError + return result + return decorator diff --git a/pony/orm/tests/model1.py b/pony/orm/tests/model1.py index 8858b3d4c..96a01c208 100644 --- a/pony/orm/tests/model1.py +++ b/pony/orm/tests/model1.py @@ -1,8 +1,9 @@ from __future__ import absolute_import, print_function, division from pony.orm.core import * +from pony.orm.tests import db_params -db = Database('sqlite', ':memory:') +db = Database(**db_params) class Student(db.Entity): _table_ = "Students" @@ -33,7 +34,8 @@ class Mark(db.Entity): PrimaryKey(student, subject) -db.generate_mapping(create_tables=True) +db.generate_mapping(check_tables=False) + @db_session def populate_db(): @@ -56,4 +58,3 @@ def populate_db(): Mark(student=s102, subject=Chemistry, value=5) Mark(student=s103, subject=Physics, value=2) Mark(student=s103, subject=Chemistry, value=4) -populate_db() diff --git a/pony/orm/tests/py36_test_f_strings.py b/pony/orm/tests/py36_test_f_strings.py index 6c291ba57..7da921c7f 100644 --- a/pony/orm/tests/py36_test_f_strings.py +++ b/pony/orm/tests/py36_test_f_strings.py @@ -1,8 +1,10 @@ import unittest from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') + +db = Database() class Person(db.Entity): first_name = Required(str) @@ -11,17 +13,22 @@ class Person(db.Entity): value = Required(float) -db.generate_mapping(create_tables=True) +class TestFString(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Person(id=1, first_name='Alexander', 
last_name='Tischenko', age=23, value=1.4) + Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2) + Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33) + Person(id=4, first_name='Okita', last_name='Souji', age=15, value=2.1) + Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9) + Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212) -with db_session: - Person(id=1, first_name='Alexander', last_name='Tischenko', age=23, value=1.4) - Person(id=2, first_name='Alexander', last_name='Kozlovskiy', age=42, value=1.2) - Person(id=3, first_name='Arthur', last_name='Pendragon', age=54, value=1.33) - Person(id=4, first_name='Okita', last_name='Souji', age=15, value=2.1) - Person(id=5, first_name='Musashi', last_name='Miyamoto', age=None, value=0.9) - Person(id=6, first_name='Jeanne', last_name="d'Arc", age=30, value=43.212) + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestFString(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_array.py b/pony/orm/tests/test_array.py index e446b2aac..ef647e055 100644 --- a/pony/orm/tests/test_array.py +++ b/pony/orm/tests/test_array.py @@ -2,10 +2,12 @@ import unittest from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, setup_database, teardown_database from pony.orm import * -db = Database('sqlite', ':memory:') +db = Database() + class Foo(db.Entity): id = PrimaryKey(int) @@ -18,13 +20,22 @@ class Foo(db.Entity): array4 = Optional(IntArray) array5 = Optional(IntArray, nullable=True) -db.generate_mapping(create_tables=True) +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + if db_params['provider'] not in ('sqlite', 'postgres'): + raise unittest.SkipTest('Arrays are only available for SQLite and PostgreSQL') + + setup_database(db) + with db_session: + Foo(id=1, a=1, b=3, c=-2, array1=[10, 20, 30, 40, 50], array2=[1.1, 
2.2, 3.3, 4.4, 5.5], + array3=['foo', 'bar']) -with db_session: - Foo(id=1, a=1, b=3, c=-2, array1=[10, 20, 30, 40, 50], array2=[1.1, 2.2, 3.3, 4.4, 5.5], array3=['foo', 'bar']) + @classmethod + def tearDownClass(cls): + teardown_database(db) -class Test(unittest.TestCase): @db_session def test_1(self): foo = select(f for f in Foo if 10 in f.array1)[:] diff --git a/pony/orm/tests/test_attribute_options.py b/pony/orm/tests/test_attribute_options.py index cf8d79591..6ba55ca46 100644 --- a/pony/orm/tests/test_attribute_options.py +++ b/pony/orm/tests/test_attribute_options.py @@ -5,11 +5,13 @@ from pony import orm from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database from pony.orm.tests.testutils import raises_exception -db = Database('sqlite', ':memory:') +db = Database() class Person(db.Entity): + id = PrimaryKey(int) name = orm.Required(str, 40) lastName = orm.Required(str, max_len=40, unique=True) age = orm.Optional(int, max=60, min=10) @@ -21,15 +23,20 @@ class Person(db.Entity): gpa = orm.Optional(float, py_check=lambda val: val >= 0 and val <= 5) vehicle = orm.Optional(str, column='car') -db.generate_mapping(create_tables=True) +class TestAttributeOptions(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with orm.db_session: + p1 = Person(id=1, name='Andrew', lastName='Bodroue', age=40, rate=0.980000000001, salaryRate=0.98000001) + p2 = Person(id=2, name='Vladimir', lastName='Andrew ', nickName='vlad ') + p3 = Person(id=3, name='Nick', lastName='Craig', middleName=None, timeStmp='2010-12-10 14:12:09.019473', + vehicle='dodge') -with orm.db_session: - p1 = Person(name='Andrew', lastName='Bodroue', age=40, rate=0.980000000001, salaryRate=0.98000001) - p2 = Person(name='Vladimir', lastName='Andrew ', nickName='vlad ') - p3 = Person(name='Nick', lastName='Craig', middleName=None, timeStmp='2010-12-10 14:12:09.019473', vehicle='dodge') + @classmethod + def tearDownClass(cls): + teardown_database(db) 
-class TestAttributeOptions(unittest.TestCase): - def setUp(self): rollback() db_session.__enter__() @@ -47,10 +54,10 @@ def test_optionalStringNone(self): self.assertIsNotNone(queryResult) def test_stringAutoStrip(self): - self.assertEqual(p2.lastName, 'Andrew') + self.assertEqual(Person[2].lastName, 'Andrew') def test_stringAutoStripFalse(self): - self.assertEqual(p2.nickName, 'vlad ') + self.assertEqual(Person[2].nickName, 'vlad ') def test_intNone(self): queryResult = select(p.id for p in Person if p.age==None).first() @@ -72,34 +79,34 @@ def test_fractionalSeconds(self): self.assertEqual(queryResult.microsecond, 19473) def test_intMax(self): - p4 = Person(name='Denis', lastName='Blanc', age=60) + p4 = Person(id=4, name='Denis', lastName='Blanc', age=60) def test_intMin(self): - p4 = Person(name='Denis', lastName='Blanc', age=10) + p4 = Person(id=4, name='Denis', lastName='Blanc', age=10) @raises_exception(ValueError, "Value 61 of attr Person.age is greater than the maximum allowed value 60") def test_intMaxException(self): - p4 = Person(name='Denis', lastName='Blanc', age=61) + p4 = Person(id=4, name='Denis', lastName='Blanc', age=61) @raises_exception(ValueError, "Value 9 of attr Person.age is less than the minimum allowed value 10") def test_intMinException(self): - p4 = Person(name='Denis', lastName='Blanc', age=9) + p4 = Person(id=4, name='Denis', lastName='Blanc', age=9) def test_py_check(self): - p4 = Person(name='Denis', lastName='Blanc', gpa=5) - p5 = Person(name='Mario', lastName='Gon', gpa=1) + p4 = Person(id=4, name='Denis', lastName='Blanc', gpa=5) + p5 = Person(id=5, name='Mario', lastName='Gon', gpa=1) flush() @raises_exception(ValueError, "Check for attribute Person.gpa failed. Value: 6.0") def test_py_checkMoreException(self): - p6 = Person(name='Daniel', lastName='Craig', gpa=6) + p6 = Person(id=6, name='Daniel', lastName='Craig', gpa=6) @raises_exception(ValueError, "Check for attribute Person.gpa failed. 
Value: -1.0") def test_py_checkLessException(self): - p6 = Person(name='Daniel', lastName='Craig', gpa=-1) + p6 = Person(id=6, name='Daniel', lastName='Craig', gpa=-1) - @raises_exception(TransactionIntegrityError, 'Object Person[new:...] cannot be stored in the database.' - ' IntegrityError: UNIQUE constraint failed: Person.lastName') + @raises_exception(TransactionIntegrityError, + 'Object Person[...] cannot be stored in the database. IntegrityError: ...') def test_unique(self): - p6 = Person(name='Boris', lastName='Bodroue') - flush() \ No newline at end of file + p6 = Person(id=6, name='Boris', lastName='Bodroue') + flush() diff --git a/pony/orm/tests/test_autostrip.py b/pony/orm/tests/test_autostrip.py index 18865536b..b4656256d 100644 --- a/pony/orm/tests/test_autostrip.py +++ b/pony/orm/tests/test_autostrip.py @@ -2,16 +2,23 @@ from pony.orm import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Person(db.Entity): name = Required(str) tel = Optional(str) -db.generate_mapping(create_tables=True) class TestAutostrip(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) @db_session def test_1(self): diff --git a/pony/orm/tests/test_buffer.py b/pony/orm/tests/test_buffer.py index 7dc9f13de..002317f9f 100644 --- a/pony/orm/tests/test_buffer.py +++ b/pony/orm/tests/test_buffer.py @@ -1,32 +1,40 @@ import unittest from pony import orm -from pony.py23compat import buffer +from pony.orm.tests import setup_database, teardown_database + +db = orm.Database() -db = orm.Database('sqlite', ':memory:') class Foo(db.Entity): id = orm.PrimaryKey(int) b = orm.Optional(orm.buffer) + class Bar(db.Entity): b = orm.PrimaryKey(orm.buffer) + class Baz(db.Entity): id = orm.PrimaryKey(int) b = orm.Optional(orm.buffer, unique=True) 
-db.generate_mapping(create_tables=True) buf = orm.buffer(b'123') -with orm.db_session: - Foo(id=1, b=buf) - Bar(b=buf) - Baz(id=1, b=buf) +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with orm.db_session: + Foo(id=1, b=buf) + Bar(b=buf) + Baz(id=1, b=buf) + @classmethod + def tearDownClass(cls): + teardown_database(db) -class Test(unittest.TestCase): def test_1(self): # Bug #355 with orm.db_session: Bar[buf] diff --git a/pony/orm/tests/test_bug_170.py b/pony/orm/tests/test_bug_170.py index e60127d78..8b38deab1 100644 --- a/pony/orm/tests/test_bug_170.py +++ b/pony/orm/tests/test_bug_170.py @@ -1,18 +1,27 @@ import unittest from pony import orm +from pony.orm.tests import setup_database, teardown_database -class Test(unittest.TestCase): - def test_1(self): - db = orm.Database('sqlite', ':memory:') +db = orm.Database() - class Person(db.Entity): - id = orm.PrimaryKey(int, auto=True) - name = orm.Required(str) - orm.composite_key(id, name) - db.generate_mapping(create_tables=True) +class Person(db.Entity): + id = orm.PrimaryKey(int, auto=True) + name = orm.Required(str) + orm.composite_key(id, name) + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): table = db.schema.tables[Person._table_] pk_column = table.column_dict[Person.id.column] self.assertTrue(pk_column.is_pk) diff --git a/pony/orm/tests/test_bug_182.py b/pony/orm/tests/test_bug_182.py index 5d59b2fe7..7581f9c67 100644 --- a/pony/orm/tests/test_bug_182.py +++ b/pony/orm/tests/test_bug_182.py @@ -1,40 +1,46 @@ - import unittest from pony.orm import * from pony import orm +from pony.orm.tests import setup_database, teardown_database +db = Database() -class Test(unittest.TestCase): - def setUp(self): - db = self.db = Database('sqlite', ':memory:') +class User(db.Entity): + name = Required(str) + servers = Set("Server") + - class 
User(db.Entity): - name = Required(str) - servers = Set("Server") +class Worker(User): + pass - class Worker(db.User): - pass - class Admin(db.Worker): - pass +class Admin(Worker): + pass - # And M:1 relationship with another entity - class Server(db.Entity): - name = Required(str) - user = Optional(User) +# And M:1 relationship with another entity +class Server(db.Entity): + name = Required(str) + user = Optional(User) - db.generate_mapping(check_tables=True, create_tables=True) +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) with orm.db_session: Server(name='s1.example.com', user=User(name="Alex")) Server(name='s2.example.com', user=Worker(name="John")) Server(name='free.example.com', user=None) + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test(self): - qu = left_join((s.name, s.user.name) for s in self.db.Server)[:] + qu = left_join((s.name, s.user.name) for s in db.Server)[:] for server, user in qu: if user is None: break diff --git a/pony/orm/tests/test_bug_331.py b/pony/orm/tests/test_bug_331.py index 00832e484..3a6af0469 100644 --- a/pony/orm/tests/test_bug_331.py +++ b/pony/orm/tests/test_bug_331.py @@ -1,25 +1,35 @@ import unittest -from pony import orm +from pony.orm.tests import setup_database, teardown_database +from pony.orm import * -class Test(unittest.TestCase): - def test_1(self): - db = orm.Database('sqlite', ':memory:') +db = Database() + + +class Person(db.Entity): + name = Required(str) + group = Optional(lambda: Group) - class Person(db.Entity): - name = orm.Required(str) - group = orm.Optional(lambda: Group) - class Group(db.Entity): - title = orm.PrimaryKey(str) - persons = orm.Set(Person) +class Group(db.Entity): + title = PrimaryKey(str) + persons = Set(Person) - def __len__(self): - return len(self.persons) + def __len__(self): + return len(self.persons) - db.generate_mapping(create_tables=True) - with orm.db_session: +class Test(unittest.TestCase): + 
@classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): + with db_session: p1 = Person(name="Alex") p2 = Person(name="Brad") p3 = Person(name="Chad") @@ -34,7 +44,7 @@ def __len__(self): g1.persons.add(p3) g2.persons.add(p4) g2.persons.add(p5) - orm.commit() + commit() foxes = Group['Foxes'] gorillas = Group['Gorillas'] diff --git a/pony/orm/tests/test_bug_386.py b/pony/orm/tests/test_bug_386.py index fba12711d..31f62574b 100644 --- a/pony/orm/tests/test_bug_386.py +++ b/pony/orm/tests/test_bug_386.py @@ -1,16 +1,25 @@ import unittest from pony import orm +from pony.orm.tests import setup_database, teardown_database -class Test(unittest.TestCase): - def test_1(self): - db = orm.Database('sqlite', ':memory:') +db = orm.Database() - class Person(db.Entity): - name = orm.Required(str) - db.generate_mapping(create_tables=True) +class Person(db.Entity): + name = orm.Required(str) + +class Test(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + def test_1(self): with orm.db_session: a = Person(name='John') a.delete() diff --git a/pony/orm/tests/test_cascade.py b/pony/orm/tests/test_cascade.py index a05cc07b3..d2f52c21f 100644 --- a/pony/orm/tests/test_cascade.py +++ b/pony/orm/tests/test_cascade.py @@ -2,11 +2,33 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + class TestCascade(unittest.TestCase): + providers = ['sqlite'] # Implement for other providers + + def tearDown(self): + if self.db.schema is not None: + teardown_database(self.db) + + def assert_on_delete(self, table_name, value): + db = self.db + if not (db.provider.dialect == 'SQLite' and pony.__version__ < '0.9'): + table_name = table_name.lower() + table = db.schema.tables[table_name] + fkeys = table.foreign_keys + self.assertEqual(1, 
len(fkeys)) + if pony.__version__ >= '0.9': + self.assertEqual(fkeys[0].on_delete, value) + elif db.provider.dialect == 'SQLite': + self.assertIn('ON DELETE %s' % value, table.get_create_command()) + else: + self.assertIn('ON DELETE %s' % value, list(fkeys.values())[0].get_create_command()) + def test_1(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Person(self.db.Entity): name = Required(str) @@ -15,12 +37,11 @@ class Person(self.db.Entity): class Group(self.db.Entity): persons = Set(Person) - db.generate_mapping(create_tables=True) - - self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command()) + setup_database(db) + self.assert_on_delete('Person', 'CASCADE') def test_2(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Person(self.db.Entity): name = Required(str) @@ -29,13 +50,11 @@ class Person(self.db.Entity): class Group(self.db.Entity): persons = Set(Person, cascade_delete=True) - db.generate_mapping(create_tables=True) - - self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command()) - + setup_database(db) + self.assert_on_delete('Person', 'CASCADE') def test_3(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Person(self.db.Entity): name = Required(str) @@ -44,13 +63,12 @@ class Person(self.db.Entity): class Group(self.db.Entity): persons = Set(Person, cascade_delete=True) - db.generate_mapping(create_tables=True) - - self.assertTrue('ON DELETE CASCADE' in self.db.schema.tables['Person'].get_create_command()) + setup_database(db) + self.assert_on_delete('Person', 'CASCADE') @raises_exception(TypeError, "'cascade_delete' option cannot be set for attribute Group.persons, because reverse attribute Person.group is collection") def test_4(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Person(self.db.Entity): name = Required(str) @@ -59,11 
+77,11 @@ class Person(self.db.Entity): class Group(self.db.Entity): persons = Set(Person, cascade_delete=True) - db.generate_mapping(create_tables=True) + setup_database(db) @raises_exception(TypeError, "'cascade_delete' option cannot be set for both sides of relationship (Person.group and Group.persons) simultaneously") def test_5(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Person(self.db.Entity): name = Required(str) @@ -72,10 +90,10 @@ class Person(self.db.Entity): class Group(self.db.Entity): persons = Required(Person, cascade_delete=True) - db.generate_mapping(create_tables=True) + setup_database(db) def test_6(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Person(self.db.Entity): name = Required(str) @@ -84,9 +102,9 @@ class Person(self.db.Entity): class Group(self.db.Entity): persons = Optional(Person) - db.generate_mapping(create_tables=True) + setup_database(db) + self.assert_on_delete('Group', 'SET NULL') - self.assertTrue('ON DELETE SET NULL' in self.db.schema.tables['Group'].get_create_command()) if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_cascade_delete.py b/pony/orm/tests/test_cascade_delete.py index e5974fc45..b2d38f9dc 100644 --- a/pony/orm/tests/test_cascade_delete.py +++ b/pony/orm/tests/test_cascade_delete.py @@ -1,49 +1,55 @@ import unittest from pony.orm import * +from pony.orm.tests import setup_database, teardown_database, only_for -db = Database('sqlite', ':memory:') +db = Database() class X(db.Entity): id = PrimaryKey(int) parent = Optional('X', reverse='children') children = Set('X', reverse='parent', cascade_delete=True) + class Y(db.Entity): parent = Optional('Y', reverse='children') children = Set('Y', reverse='parent', cascade_delete=True, lazy=True) -db.generate_mapping(create_tables=True) - -with db_session: - x1 = X(id=1) - x2 = X(id=2, parent=x1) - x3 = X(id=3, parent=x1) - x4 = X(id=4, parent=x3) - x5 = X(id=5, 
parent=x3) - x6 = X(id=6, parent=x5) - x7 = X(id=7, parent=x3) - x8 = X(id=8, parent=x7) - x9 = X(id=9, parent=x7) - x10 = X(id=10) - x11 = X(id=11, parent=x10) - x12 = X(id=12, parent=x10) +class TestCascade(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + x1 = X(id=1) + x2 = X(id=2, parent=x1) + x3 = X(id=3, parent=x1) + x4 = X(id=4, parent=x3) + x5 = X(id=5, parent=x3) + x6 = X(id=6, parent=x5) + x7 = X(id=7, parent=x3) + x8 = X(id=8, parent=x7) + x9 = X(id=9, parent=x7) + x10 = X(id=10) + x11 = X(id=11, parent=x10) + x12 = X(id=12, parent=x10) - y1 = Y(id=1) - y2 = Y(id=2, parent=y1) - y3 = Y(id=3, parent=y1) - y4 = Y(id=4, parent=y3) - y5 = Y(id=5, parent=y3) - y6 = Y(id=6, parent=y5) - y7 = Y(id=7, parent=y3) - y8 = Y(id=8, parent=y7) - y9 = Y(id=9, parent=y7) - y10 = Y(id=10) - y11 = Y(id=11, parent=y10) - y12 = Y(id=12, parent=y10) + y1 = Y(id=1) + y2 = Y(id=2, parent=y1) + y3 = Y(id=3, parent=y1) + y4 = Y(id=4, parent=y3) + y5 = Y(id=5, parent=y3) + y6 = Y(id=6, parent=y5) + y7 = Y(id=7, parent=y3) + y8 = Y(id=8, parent=y7) + y9 = Y(id=9, parent=y7) + y10 = Y(id=10) + y11 = Y(id=11, parent=y10) + y12 = Y(id=12, parent=y10) -class TestCascade(unittest.TestCase): + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() diff --git a/pony/orm/tests/test_collections.py b/pony/orm/tests/test_collections.py index 15242f148..a6a54e1bb 100644 --- a/pony/orm/tests/test_collections.py +++ b/pony/orm/tests/test_collections.py @@ -5,8 +5,18 @@ from pony.orm.tests.testutils import raises_exception from pony.orm.tests.model1 import * +from pony.orm.tests import setup_database, teardown_database + class TestCollections(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + populate_db() + + @classmethod + def tearDownClass(cls): + db.drop_all_tables(with_all_data=True) @db_session def test_setwrapper_len(self): diff --git 
a/pony/orm/tests/test_core_find_in_cache.py b/pony/orm/tests/test_core_find_in_cache.py index d0ec1b29d..d803809b8 100644 --- a/pony/orm/tests/test_core_find_in_cache.py +++ b/pony/orm/tests/test_core_find_in_cache.py @@ -3,8 +3,9 @@ import unittest from pony.orm.tests.testutils import raises_exception from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class AbstractUser(db.Entity): username = PrimaryKey(unicode) @@ -32,31 +33,40 @@ class Diagram(db.Entity): name = Required(unicode) owner = Required(User) -db.generate_mapping(create_tables=True) - -with db_session: - u1 = User(username='user1') - u2 = SubUser1(username='subuser1', attr1='some attr') - u3 = SubUser2(username='subuser2', attr2='some attr') - o1 = Organization(username='org1') - o2 = SubOrg1(username='suborg1', attr3='some attr') - o3 = SubOrg2(username='suborg2', attr4='some attr') - au = AbstractUser(username='abstractUser') - Diagram(name='diagram1', owner=u1) - Diagram(name='diagram2', owner=u2) - Diagram(name='diagram3', owner=u3) def is_seed(entity, pk): cache = entity._database_._get_cache() return pk in [ obj._pk_ for obj in cache.seeds[entity._pk_attrs_] ] + class TestFindInCache(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + u1 = User(username='user1') + u2 = SubUser1(username='subuser1', attr1='some attr') + u3 = SubUser2(username='subuser2', attr2='some attr') + o1 = Organization(username='org1') + o2 = SubOrg1(username='suborg1', attr3='some attr') + o3 = SubOrg2(username='suborg2', attr4='some attr') + au = AbstractUser(username='abstractUser') + Diagram(name='diagram1', owner=u1) + Diagram(name='diagram2', owner=u2) + Diagram(name='diagram3', owner=u3) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() + def tearDown(self): rollback() db_session.__exit__() + def test1(self): u = 
User.get(username='org1') org = Organization.get(username='org1') @@ -72,6 +82,7 @@ def test_user_1(self): u = AbstractUser['user1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, User) + def test_user_2(self): Diagram.get(lambda d: d.name == 'diagram1') last_sql = db.last_sql @@ -79,6 +90,7 @@ def test_user_2(self): u = User['user1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, User) + @raises_exception(ObjectNotFound) def test_user_3(self): Diagram.get(lambda d: d.name == 'diagram1') @@ -88,6 +100,7 @@ def test_user_3(self): SubUser1['user1'] finally: self.assertNotEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_user_4(self): Diagram.get(lambda d: d.name == 'diagram1') @@ -97,6 +110,7 @@ def test_user_4(self): Organization['user1'] finally: self.assertEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_user_5(self): Diagram.get(lambda d: d.name == 'diagram1') @@ -107,7 +121,6 @@ def test_user_5(self): finally: self.assertEqual(last_sql, db.last_sql) - def test_subuser_1(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql @@ -115,6 +128,7 @@ def test_subuser_1(self): u = AbstractUser['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) + def test_subuser_2(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql @@ -122,6 +136,7 @@ def test_subuser_2(self): u = User['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) + def test_subuser_3(self): Diagram.get(lambda d: d.name == 'diagram2') last_sql = db.last_sql @@ -129,6 +144,7 @@ def test_subuser_3(self): u = SubUser1['subuser1'] self.assertNotEqual(last_sql, db.last_sql) self.assertEqual(u.__class__, SubUser1) + @raises_exception(ObjectNotFound) def test_subuser_4(self): Diagram.get(lambda d: d.name == 'diagram2') @@ -138,6 +154,7 @@ def test_subuser_4(self): Organization['subuser1'] 
finally: self.assertEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_subuser_5(self): Diagram.get(lambda d: d.name == 'diagram2') @@ -147,6 +164,7 @@ def test_subuser_5(self): SubUser2['subuser1'] finally: self.assertNotEqual(last_sql, db.last_sql) + @raises_exception(ObjectNotFound) def test_subuser_6(self): Diagram.get(lambda d: d.name == 'diagram2') @@ -163,6 +181,7 @@ def test_user_6(self): u2 = SubUser1['subuser1'] self.assertEqual(last_sql, db.last_sql) self.assertEqual(u1, u2) + def test_user_7(self): u1 = SubUser1['subuser1'] u1.delete() @@ -170,6 +189,7 @@ def test_user_7(self): u2 = SubUser1.get(username='subuser1') self.assertEqual(last_sql, db.last_sql) self.assertEqual(u2, None) + def test_user_8(self): u1 = SubUser1['subuser1'] last_sql = db.last_sql diff --git a/pony/orm/tests/test_core_multiset.py b/pony/orm/tests/test_core_multiset.py index 278e33df9..65d5153dd 100644 --- a/pony/orm/tests/test_core_multiset.py +++ b/pony/orm/tests/test_core_multiset.py @@ -4,8 +4,9 @@ from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): number = PrimaryKey(int) @@ -27,33 +28,40 @@ class Course(db.Entity): department = Required(Department) students = Set('Student') -db.generate_mapping(create_tables=True) -with db_session: - d1 = Department(number=1) - d2 = Department(number=2) - d3 = Department(number=3) +class TestMultiset(unittest.TestCase): - g1 = Group(number=101, department=d1) - g2 = Group(number=102, department=d1) - g3 = Group(number=201, department=d2) + @classmethod + def setUpClass(cls): + setup_database(db) - c1 = Course(name='C1', department=d1) - c2 = Course(name='C2', department=d1) - c3 = Course(name='C3', department=d2) - c4 = Course(name='C4', department=d2) - c5 = Course(name='C5', department=d3) + with db_session: + d1 = Department(number=1) + d2 = Department(number=2) + d3 = Department(number=3) - s1 = 
Student(name='S1', group=g1, courses=[c1, c2]) - s2 = Student(name='S2', group=g1, courses=[c1, c3]) - s3 = Student(name='S3', group=g1, courses=[c2, c3]) + g1 = Group(number=101, department=d1) + g2 = Group(number=102, department=d1) + g3 = Group(number=201, department=d2) - s4 = Student(name='S4', group=g2, courses=[c1, c2]) - s5 = Student(name='S5', group=g2, courses=[c1, c2]) + c1 = Course(name='C1', department=d1) + c2 = Course(name='C2', department=d1) + c3 = Course(name='C3', department=d2) + c4 = Course(name='C4', department=d2) + c5 = Course(name='C5', department=d3) - s6 = Student(name='A', group=g3, courses=[c5]) + s1 = Student(name='S1', group=g1, courses=[c1, c2]) + s2 = Student(name='S2', group=g1, courses=[c1, c3]) + s3 = Student(name='S3', group=g1, courses=[c2, c3]) -class TestMultiset(unittest.TestCase): + s4 = Student(name='S4', group=g2, courses=[c1, c2]) + s5 = Student(name='S5', group=g2, courses=[c1, c2]) + + s6 = Student(name='A', group=g3, courses=[c5]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) @db_session def test_multiset_repr_1(self): @@ -138,5 +146,6 @@ def test_multiset_reduce(self): multiset_1 = pickle.loads(s) self.assertEqual(multiset_1, multiset_2) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_crud.py b/pony/orm/tests/test_crud.py index ea4fcc0f6..540a92bcf 100644 --- a/pony/orm/tests/test_crud.py +++ b/pony/orm/tests/test_crud.py @@ -6,8 +6,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Group(db.Entity): id = PrimaryKey(int) @@ -25,25 +26,31 @@ class Student(db.Entity): group = Optional('Group') class Course(db.Entity): + id = PrimaryKey(int) name = Required(unicode) semester = Required(int) students = Set(Student) composite_key(name, semester) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(id=1, 
major='Math') - g2 = Group(id=2, major='Physics') - s1 = Student(id=1, name='S1', age=19, email='s1@example.com', group=g1) - s2 = Student(id=2, name='S2', age=21, email='s2@example.com', group=g1) - s3 = Student(id=3, name='S3', email='s3@example.com', group=g2) - c1 = Course(name='Math', semester=1) - c2 = Course(name='Math', semester=2) - c3 = Course(name='Physics', semester=1) - class TestCRUD(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1, major='Math') + g2 = Group(id=2, major='Physics') + s1 = Student(id=1, name='S1', age=19, email='s1@example.com', group=g1) + s2 = Student(id=2, name='S2', age=21, email='s2@example.com', group=g1) + s3 = Student(id=3, name='S3', email='s3@example.com', group=g2) + c1 = Course(id=1, name='Math', semester=1) + c2 = Course(id=2, name='Math', semester=2) + c3 = Course(id=3, name='Physics', semester=1) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -121,6 +128,7 @@ def test_validate_3(self): @raises_exception(ValueError, "Value type for attribute Group.id must be int. Got string 'not a number'") def test_validate_5(self): s4 = Student(id=3, name='S4', email='s4@example.com', group='not a number') + @raises_exception(TypeError, "Attribute Student.group must be of Group type. 
Got: datetime.date(2011, 1, 1)") def test_validate_6(self): s4 = Student(id=3, name='S4', email='s4@example.com', group=date(2011, 1, 1)) diff --git a/pony/orm/tests/test_crud_raw_sql.py b/pony/orm/tests/test_crud_raw_sql.py index 1ebab6651..03dd23348 100644 --- a/pony/orm/tests/test_crud_raw_sql.py +++ b/pony/orm/tests/test_crud_raw_sql.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database, only_for -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -25,9 +26,17 @@ class Bio(db.Entity): desc = Required(unicode) Student = Required(Student) -db.generate_mapping(create_tables=True) +@only_for('sqlite') class TestCrudRawSQL(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('delete from Student') diff --git a/pony/orm/tests/test_datetime.py b/pony/orm/tests/test_datetime.py index 42448efb5..7bb60accb 100644 --- a/pony/orm/tests/test_datetime.py +++ b/pony/orm/tests/test_datetime.py @@ -6,22 +6,30 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Entity1(db.Entity): id = PrimaryKey(int) d = Required(date) dt = Required(datetime) -db.generate_mapping(create_tables=True) - -with db_session: - Entity1(id=1, d=date(2009, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) - Entity1(id=2, d=date(2010, 10, 21), dt=datetime(2010, 10, 21, 10, 21, 31)) - Entity1(id=3, d=date(2011, 11, 22), dt=datetime(2011, 11, 22, 10, 20, 32)) class TestDate(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Entity1(id=1, d=date(2009, 10, 20), dt=datetime(2009, 10, 20, 10, 20, 30)) + Entity1(id=2, 
d=date(2010, 10, 21), dt=datetime(2010, 10, 21, 10, 21, 31)) + Entity1(id=3, d=date(2011, 11, 22), dt=datetime(2011, 11, 22, 10, 20, 32)) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_db_session.py b/pony/orm/tests/test_db_session.py index 8f0d8795d..09cd08ed0 100644 --- a/pony/orm/tests/test_db_session.py +++ b/pony/orm/tests/test_db_session.py @@ -7,19 +7,25 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + class TestDBSession(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + self.db = Database() class X(self.db.Entity): a = PrimaryKey(int) b = Optional(int) self.X = X - self.db.generate_mapping(create_tables=True) + setup_database(self.db) with db_session: x1 = X(a=1, b=1) x2 = X(a=2, b=2) + def tearDown(self): + if self.db.provider.dialect != 'SQLite': + teardown_database(self.db) + @raises_exception(TypeError, "Pass only keyword arguments to db_session or use db_session as decorator") def test_db_session_1(self): db_session(1, 2, 3) @@ -375,7 +381,8 @@ def test_db_session_exceptions_4(self): connection.close() 1/0 -db = Database('sqlite', ':memory:') + +db = Database() class Group(db.Entity): id = PrimaryKey(int) @@ -387,17 +394,21 @@ class Student(db.Entity): picture = Optional(buffer, lazy=True) group = Required('Group') -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(id=1, major='Math') - g2 = Group(id=2, major='Physics') - s1 = Student(id=1, name='S1', group=g1) - s2 = Student(id=2, name='S2', group=g1) - s3 = Student(id=3, name='S3', group=g2) +class TestDBSessionScope(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1, major='Math') + g2 = Group(id=2, major='Physics') + s1 = Student(id=1, name='S1', group=g1) + s2 = Student(id=2, name='S2', 
group=g1) + s3 = Student(id=3, name='S3', group=g2) + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestDBSessionScope(unittest.TestCase): def setUp(self): rollback() diff --git a/pony/orm/tests/test_declarative_attr_set_monad.py b/pony/orm/tests/test_declarative_attr_set_monad.py index 6217277f1..d14557342 100644 --- a/pony/orm/tests/test_declarative_attr_set_monad.py +++ b/pony/orm/tests/test_declarative_attr_set_monad.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -30,38 +31,44 @@ class Mark(db.Entity): subject = Required(Subject) PrimaryKey(student, subject) -db.generate_mapping(create_tables=True) - -with db_session: - g41 = Group(number=41, department=101) - g42 = Group(number=42, department=102) - g43 = Group(number=43, department=102) - g44 = Group(number=44, department=102) - - s1 = Student(id=1, name="Joe", scholarship=None, group=g41) - s2 = Student(id=2, name="Bob", scholarship=100, group=g41) - s3 = Student(id=3, name="Beth", scholarship=500, group=g41) - s4 = Student(id=4, name="Jon", scholarship=500, group=g42) - s5 = Student(id=5, name="Pete", scholarship=700, group=g42) - s6 = Student(id=6, name="Mary", scholarship=300, group=g44) - - Math = Subject(name="Math") - Physics = Subject(name="Physics") - History = Subject(name="History") - - g41.subjects = [ Math, Physics, History ] - g42.subjects = [ Math, Physics ] - g43.subjects = [ Physics ] - - Mark(value=5, student=s1, subject=Math) - Mark(value=4, student=s2, subject=Physics) - Mark(value=3, student=s2, subject=Math) - Mark(value=2, student=s2, subject=History) - Mark(value=1, student=s3, subject=History) - Mark(value=2, student=s3, subject=Math) - Mark(value=2, student=s4, subject=Math) class TestAttrSetMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): 
+ setup_database(db) + with db_session: + g41 = Group(number=41, department=101) + g42 = Group(number=42, department=102) + g43 = Group(number=43, department=102) + g44 = Group(number=44, department=102) + + s1 = Student(id=1, name="Joe", scholarship=None, group=g41) + s2 = Student(id=2, name="Bob", scholarship=100, group=g41) + s3 = Student(id=3, name="Beth", scholarship=500, group=g41) + s4 = Student(id=4, name="Jon", scholarship=500, group=g42) + s5 = Student(id=5, name="Pete", scholarship=700, group=g42) + s6 = Student(id=6, name="Mary", scholarship=300, group=g44) + + Math = Subject(name="Math") + Physics = Subject(name="Physics") + History = Subject(name="History") + + g41.subjects = [Math, Physics, History] + g42.subjects = [Math, Physics] + g43.subjects = [Physics] + + Mark(value=5, student=s1, subject=Math) + Mark(value=4, student=s2, subject=Physics) + Mark(value=3, student=s2, subject=Math) + Mark(value=2, student=s2, subject=History) + Mark(value=1, student=s3, subject=History) + Mark(value=2, student=s3, subject=Math) + Mark(value=2, student=s4, subject=Math) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -81,7 +88,7 @@ def test3(self): self.assertEqual(groups, [Group[41]]) def test3a(self): groups = select(g for g in Group if len(g.students.marks) < 2)[:] - self.assertEqual(groups, [Group[42], Group[43], Group[44]]) + self.assertEqual(set(groups), {Group[42], Group[43], Group[44]}) def test4(self): groups = select(g for g in Group if max(g.students.marks.value) <= 2)[:] self.assertEqual(groups, [Group[42]]) diff --git a/pony/orm/tests/test_declarative_exceptions.py b/pony/orm/tests/test_declarative_exceptions.py index b4fa44631..6154086ff 100644 --- a/pony/orm/tests/test_declarative_exceptions.py +++ b/pony/orm/tests/test_declarative_exceptions.py @@ -8,8 +8,9 @@ from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from 
pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -34,16 +35,20 @@ class Course(db.Entity): PrimaryKey(name, semester) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - d1 = Department(number=44) - g1 = Group(number=101, dept=d1) - Student(name='S1', group=g1) - Student(name='S2', group=g1) - Student(name='S3', group=g1) class TestSQLTranslatorExceptions(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=44) + g1 = Group(number=101, dept=d1) + Student(name='S1', group=g1) + Student(name='S2', group=g1) + Student(name='S3', group=g1) + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -118,9 +123,6 @@ def test18(self): % unicode.__name__) def test19(self): select(s for s in Student if s.name[1:'a'] == 'A') - @raises_exception(NotImplementedError, "Negative indices are not supported in string slice s.name[-1:1]") - def test20(self): - select(s for s in Student if s.name[-1:1] == 'A') @raises_exception(TypeError, "String indices must be integers. 
Got '%s' in expression s.name['a']" % unicode.__name__) def test21(self): select(s.name for s in Student if s.name['a'] == 'h') diff --git a/pony/orm/tests/test_declarative_func_monad.py b/pony/orm/tests/test_declarative_func_monad.py index 2a7b3f0a3..707ed5e11 100644 --- a/pony/orm/tests/test_declarative_func_monad.py +++ b/pony/orm/tests/test_declarative_func_monad.py @@ -8,8 +8,9 @@ from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): id = PrimaryKey(int) @@ -25,29 +26,31 @@ class Group(db.Entity): students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - g2 = Group(number=2) - - Student(id=1, name="AA", dob=date(1981, 1, 1), last_visit=datetime(2011, 1, 1, 11, 11, 11), - scholarship=Decimal("0"), phd=True, group=g1) - - Student(id=2, name="BB", dob=date(1982, 2, 2), last_visit=datetime(2011, 2, 2, 12, 12, 12), - scholarship=Decimal("202.2"), phd=True, group=g1) - - Student(id=3, name="CC", dob=date(1983, 3, 3), last_visit=datetime(2011, 3, 3, 13, 13, 13), - scholarship=Decimal("303.3"), phd=False, group=g1) - - Student(id=4, name="DD", dob=date(1984, 4, 4), last_visit=datetime(2011, 4, 4, 14, 14, 14), - scholarship=Decimal("404.4"), phd=False, group=g2) - - Student(id=5, name="EE", dob=date(1985, 5, 5), last_visit=datetime(2011, 5, 5, 15, 15, 15), - scholarship=Decimal("505.5"), phd=False, group=g2) - - class TestFuncMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + + Student(id=1, name="AA", dob=date(1981, 1, 1), last_visit=datetime(2011, 1, 1, 11, 11, 11), + scholarship=Decimal("0"), phd=True, group=g1) + + Student(id=2, name="BB", dob=date(1982, 2, 2), last_visit=datetime(2011, 2, 2, 12, 12, 12), 
+ scholarship=Decimal("202.2"), phd=True, group=g1) + + Student(id=3, name="CC", dob=date(1983, 3, 3), last_visit=datetime(2011, 3, 3, 13, 13, 13), + scholarship=Decimal("303.3"), phd=False, group=g1) + + Student(id=4, name="DD", dob=date(1984, 4, 4), last_visit=datetime(2011, 4, 4, 14, 14, 14), + scholarship=Decimal("404.4"), phd=False, group=g2) + + Student(id=5, name="EE", dob=date(1985, 5, 5), last_visit=datetime(2011, 5, 5, 15, 15, 15), + scholarship=Decimal("505.5"), phd=False, group=g2) + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -130,7 +133,10 @@ def test_decimal_func(self): self.assertEqual(result, {Student[3], Student[4], Student[5]}) def test_concat_1(self): result = set(select(concat(s.name, ':', s.dob.year, ':', s.scholarship) for s in Student)) - self.assertEqual(result, {'AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'}) + if db.provider.dialect == 'PostgreSQL': + self.assertEqual(result, {'AA:1981:0.00', 'BB:1982:202.20', 'CC:1983:303.30', 'DD:1984:404.40', 'EE:1985:505.50'}) + else: + self.assertEqual(result, {'AA:1981:0', 'BB:1982:202.2', 'CC:1983:303.3', 'DD:1984:404.4', 'EE:1985:505.5'}) @raises_exception(TranslationError, 'Invalid argument of concat() function: g.students') def test_concat_2(self): result = set(select(concat(g.number, g.students) for g in Group)) diff --git a/pony/orm/tests/test_declarative_join_optimization.py b/pony/orm/tests/test_declarative_join_optimization.py index 8bc380a77..98eb8b99a 100644 --- a/pony/orm/tests/test_declarative_join_optimization.py +++ b/pony/orm/tests/test_declarative_join_optimization.py @@ -5,8 +5,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): name = Required(str) @@ -37,9 +38,13 @@ class Student(db.Entity): courses = Set(Course) 
-db.generate_mapping(create_tables=True) - class TestM2MOptimization(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() @@ -68,8 +73,17 @@ def test5(self): def test7(self): q = select(s for s in Student if sum(c.credits for c in Course if s.group.dept == c.dept) > 10) objects = q[:] - self.assertEqual(str(q._translator.sqlquery.from_ast), - "['FROM', ['s', 'TABLE', 'Student'], ['group', 'TABLE', 'Group', ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group', 'number']]]]") + student_table_name = 'Student' + group_table_name = 'Group' + if not (db.provider.dialect == 'SQLite' and pony.__version__ < '0.9'): + student_table_name = student_table_name.lower() + group_table_name = group_table_name.lower() + self.assertEqual(q._translator.sqlquery.from_ast, [ + 'FROM', ['s', 'TABLE', student_table_name], + ['group', 'TABLE', group_table_name, + ['EQ', ['COLUMN', 's', 'group'], ['COLUMN', 'group', 'number']] + ] + ]) if __name__ == '__main__': diff --git a/pony/orm/tests/test_declarative_object_flat_monad.py b/pony/orm/tests/test_declarative_object_flat_monad.py index c3c33c0c5..4b695b76b 100644 --- a/pony/orm/tests/test_declarative_object_flat_monad.py +++ b/pony/orm/tests/test_declarative_object_flat_monad.py @@ -2,8 +2,9 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) @@ -28,32 +29,38 @@ class Mark(db.Entity): subject = Required(Subject) PrimaryKey(student, subject) -db.generate_mapping(create_tables=True) -with db_session: - Math = Subject(name="Math") - Physics = Subject(name="Physics") - History = Subject(name="History") +class TestObjectFlatMonad(unittest.TestCase): + @classmethod + def setUpClass(self): + setup_database(db) + with db_session: + Math = 
Subject(name="Math") + Physics = Subject(name="Physics") + History = Subject(name="History") - g41 = Group(number=41, department=101, subjects=[ Math, Physics, History ]) - g42 = Group(number=42, department=102, subjects=[ Math, Physics ]) - g43 = Group(number=43, department=102, subjects=[ Physics ]) + g41 = Group(number=41, department=101, subjects=[Math, Physics, History]) + g42 = Group(number=42, department=102, subjects=[Math, Physics]) + g43 = Group(number=43, department=102, subjects=[Physics]) - s1 = Student(id=1, name="Joe", scholarship=None, group=g41) - s2 = Student(id=2, name="Bob", scholarship=100, group=g41) - s3 = Student(id=3, name="Beth", scholarship=500, group=g41) - s4 = Student(id=4, name="Jon", scholarship=500, group=g42) - s5 = Student(id=5, name="Pete", scholarship=700, group=g42) + s1 = Student(id=1, name="Joe", scholarship=None, group=g41) + s2 = Student(id=2, name="Bob", scholarship=100, group=g41) + s3 = Student(id=3, name="Beth", scholarship=500, group=g41) + s4 = Student(id=4, name="Jon", scholarship=500, group=g42) + s5 = Student(id=5, name="Pete", scholarship=700, group=g42) - Mark(value=5, student=s1, subject=Math) - Mark(value=4, student=s2, subject=Physics) - Mark(value=3, student=s2, subject=Math) - Mark(value=2, student=s2, subject=History) - Mark(value=1, student=s3, subject=History) - Mark(value=2, student=s3, subject=Math) - Mark(value=2, student=s4, subject=Math) + Mark(value=5, student=s1, subject=Math) + Mark(value=4, student=s2, subject=Physics) + Mark(value=3, student=s2, subject=Math) + Mark(value=2, student=s2, subject=History) + Mark(value=1, student=s3, subject=History) + Mark(value=2, student=s3, subject=Math) + Mark(value=2, student=s4, subject=Math) + + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestObjectFlatMonad(unittest.TestCase): @db_session def test1(self): result = set(select(s.groups for s in Subject if len(s.name) == 4)) diff --git 
a/pony/orm/tests/test_declarative_orderby_limit.py b/pony/orm/tests/test_declarative_orderby_limit.py index b7467dfba..bb906f392 100644 --- a/pony/orm/tests/test_declarative_orderby_limit.py +++ b/pony/orm/tests/test_declarative_orderby_limit.py @@ -4,24 +4,31 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): name = Required(unicode) scholarship = Optional(int) group = Required(int) -db.generate_mapping(create_tables=True) - -with db_session: - Student(id=1, name="B", scholarship=None, group=41) - Student(id=2, name="C", scholarship=700, group=41) - Student(id=3, name="A", scholarship=500, group=42) - Student(id=4, name="D", scholarship=500, group=43) - Student(id=5, name="E", scholarship=700, group=42) class TestOrderbyLimit(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Student(id=1, name="B", scholarship=None, group=41) + Student(id=2, name="C", scholarship=700, group=41) + Student(id=3, name="A", scholarship=500, group=42) + Student(id=4, name="D", scholarship=500, group=43) + Student(id=5, name="E", scholarship=700, group=42) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -119,7 +126,11 @@ def test19(self): def test20(self): q = select(s for s in Student).limit(offset=2) self.assertEqual(set(q), {Student[3], Student[4], Student[5]}) - self.assertTrue('LIMIT -1 OFFSET 2' in db.last_sql) + last_sql = db.last_sql + if db.provider.dialect == 'PostgreSQL': + self.assertTrue('LIMIT null OFFSET 2' in last_sql) + else: + self.assertTrue('LIMIT -1 OFFSET 2' in last_sql) def test21(self): q = select(s for s in Student).limit(0, offset=2) diff --git a/pony/orm/tests/test_declarative_query_set_monad.py b/pony/orm/tests/test_declarative_query_set_monad.py index 7c545a2d4..3edc8f52e 
100644 --- a/pony/orm/tests/test_declarative_query_set_monad.py +++ b/pony/orm/tests/test_declarative_query_set_monad.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Group(db.Entity): id = PrimaryKey(int) @@ -24,20 +25,25 @@ class Course(db.Entity): PrimaryKey(name, semester) students = Set('Student') -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(id=1) - g2 = Group(id=2) - s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) - s2 = Student(id=2, name='S2', age=23, group=g1, scholarship=100) - s3 = Student(id=3, name='S3', age=23, group=g2, scholarship=500) - c1 = Course(name='C1', semester=1, students=[s1, s2]) - c2 = Course(name='C2', semester=1, students=[s2, s3]) - c3 = Course(name='C3', semester=2, students=[s3]) - class TestQuerySetMonad(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1) + g2 = Group(id=2) + s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) + s2 = Student(id=2, name='S2', age=23, group=g1, scholarship=100) + s3 = Student(id=3, name='S3', age=23, group=g2, scholarship=500) + c1 = Course(name='C1', semester=1, students=[s1, s2]) + c2 = Course(name='C2', semester=1, students=[s2, s3]) + c3 = Course(name='C3', semester=2, students=[s3]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_declarative_sqltranslator.py b/pony/orm/tests/test_declarative_sqltranslator.py index ccddd5779..477d5c1ec 100644 --- a/pony/orm/tests/test_declarative_sqltranslator.py +++ b/pony/orm/tests/test_declarative_sqltranslator.py @@ -6,8 +6,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', 
':memory:') +db = Database() class Department(db.Entity): number = PrimaryKey(int) @@ -53,43 +54,48 @@ class Room(db.Entity): name = PrimaryKey(unicode) groups = Set(Group) -db.generate_mapping(create_tables=True) -with db_session: - d1 = Department(number=44) - d2 = Department(number=43) - g1 = Group(id=1, dept=d1) - g2 = Group(id=2, dept=d2) - s1 = Student(id=1, name='S1', group=g1, scholarship=0) - s2 = Student(id=2, name='S2', group=g1, scholarship=100) - s3 = Student(id=3, name='S3', group=g2, scholarship=500) - c1 = Course(name='Math', semester=1, dept=d1) - c2 = Course(name='Economics', semester=1, dept=d1, credits=3) - c3 = Course(name='Physics', semester=2, dept=d2) - t1 = Teacher(id=101, name="T1") - t2 = Teacher(id=102, name="T2") - Grade(student=s1, course=c1, value='C', teacher=t2, date=date(2011, 1, 1)) - Grade(student=s1, course=c3, value='A', teacher=t1, date=date(2011, 2, 1)) - Grade(student=s2, course=c2, value='B', teacher=t1) - r1 = Room(name='Room1') - r2 = Room(name='Room2') - r3 = Room(name='Room3') - g1.rooms = [ r1, r2 ] - g2.rooms = [ r2, r3 ] - c1.students.add(s1) - c1.students.add(s2) - c2.students.add(s2) - -db2 = Database('sqlite', ':memory:') +db2 = Database() class Room2(db2.Entity): name = PrimaryKey(unicode) -db2.generate_mapping(create_tables=True) - name1 = 'S1' + class TestSQLTranslator(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=44) + d2 = Department(number=43) + g1 = Group(id=1, dept=d1) + g2 = Group(id=2, dept=d2) + s1 = Student(id=1, name='S1', group=g1, scholarship=0) + s2 = Student(id=2, name='S2', group=g1, scholarship=100) + s3 = Student(id=3, name='S3', group=g2, scholarship=500) + c1 = Course(name='Math', semester=1, dept=d1) + c2 = Course(name='Economics', semester=1, dept=d1, credits=3) + c3 = Course(name='Physics', semester=2, dept=d2) + t1 = Teacher(id=101, name="T1") + t2 = Teacher(id=102, name="T2") + Grade(student=s1, course=c1, 
value='C', teacher=t2, date=date(2011, 1, 1)) + Grade(student=s1, course=c3, value='A', teacher=t1, date=date(2011, 2, 1)) + Grade(student=s2, course=c2, value='B', teacher=t1) + r1 = Room(name='Room1') + r2 = Room(name='Room2') + r3 = Room(name='Room3') + g1.rooms = [r1, r2] + g2.rooms = [r2, r3] + c1.students.add(s1) + c1.students.add(s2) + c2.students.add(s2) + setup_database(db2) + @classmethod + def tearDownClass(cls): + teardown_database(db) + teardown_database(db2) def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_declarative_sqltranslator2.py b/pony/orm/tests/test_declarative_sqltranslator2.py index a45889e60..ccddad377 100644 --- a/pony/orm/tests/test_declarative_sqltranslator2.py +++ b/pony/orm/tests/test_declarative_sqltranslator2.py @@ -7,8 +7,9 @@ from pony.orm.core import * from pony.orm.sqltranslation import IncomparableTypesError from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): number = PrimaryKey(int, auto=True) @@ -43,51 +44,56 @@ class Student(db.Entity): group = Required(Group) courses = Set(Course) -db.generate_mapping(create_tables=True) - -with db_session: - d1 = Department(name="Department of Computer Science") - d2 = Department(name="Department of Mathematical Sciences") - d3 = Department(name="Department of Applied Physics") - - c1 = Course(name="Web Design", semester=1, dept=d1, - lect_hours=30, lab_hours=30, credits=3) - c2 = Course(name="Data Structures and Algorithms", semester=3, dept=d1, - lect_hours=40, lab_hours=20, credits=4) - - c3 = Course(name="Linear Algebra", semester=1, dept=d2, - lect_hours=30, lab_hours=30, credits=4) - c4 = Course(name="Statistical Methods", semester=2, dept=d2, - lect_hours=50, lab_hours=25, credits=5) - - c5 = Course(name="Thermodynamics", semester=2, dept=d3, - lect_hours=25, lab_hours=40, credits=4) - c6 = Course(name="Quantum 
Mechanics", semester=3, dept=d3, - lect_hours=40, lab_hours=30, credits=5) - - g101 = Group(number=101, major='B.E. in Computer Engineering', dept=d1) - g102 = Group(number=102, major='B.S./M.S. in Computer Science', dept=d2) - g103 = Group(number=103, major='B.S. in Applied Mathematics and Statistics', dept=d2) - g104 = Group(number=104, major='B.S./M.S. in Pure Mathematics', dept=d2) - g105 = Group(number=105, major='B.E in Electronics', dept=d3) - g106 = Group(number=106, major='B.S./M.S. in Nuclear Engineering', dept=d3) - - Student(name='John Smith', dob=date(1991, 3, 20), tel='123-456', gpa=3, group=g101, phd=True, - courses=[c1, c2, c4, c6]) - Student(name='Matthew Reed', dob=date(1990, 11, 26), gpa=3.5, group=g101, phd=True, - courses=[c1, c3, c4, c5]) - Student(name='Chuan Qin', dob=date(1989, 2, 5), gpa=4, group=g101, - courses=[c3, c5, c6]) - Student(name='Rebecca Lawson', dob=date(1990, 4, 18), tel='234-567', gpa=3.3, group=g102, - courses=[c1, c4, c5, c6]) - Student(name='Maria Ionescu', dob=date(1991, 4, 23), gpa=3.9, group=g102, - courses=[c1, c2, c4, c6]) - Student(name='Oliver Blakey', dob=date(1990, 9, 8), gpa=3.1, group=g102, - courses=[c1, c2, c5]) - Student(name='Jing Xia', dob=date(1988, 12, 30), gpa=3.2, group=g102, - courses=[c1, c3, c5, c6]) - class TestSQLTranslator2(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=1, name="Department of Computer Science") + d2 = Department(number=2, name="Department of Mathematical Sciences") + d3 = Department(number=3, name="Department of Applied Physics") + + c1 = Course(name="Web Design", semester=1, dept=d1, + lect_hours=30, lab_hours=30, credits=3) + c2 = Course(name="Data Structures and Algorithms", semester=3, dept=d1, + lect_hours=40, lab_hours=20, credits=4) + + c3 = Course(name="Linear Algebra", semester=1, dept=d2, + lect_hours=30, lab_hours=30, credits=4) + c4 = Course(name="Statistical Methods", semester=2, dept=d2, + 
lect_hours=50, lab_hours=25, credits=5) + + c5 = Course(name="Thermodynamics", semester=2, dept=d3, + lect_hours=25, lab_hours=40, credits=4) + c6 = Course(name="Quantum Mechanics", semester=3, dept=d3, + lect_hours=40, lab_hours=30, credits=5) + + g101 = Group(number=101, major='B.E. in Computer Engineering', dept=d1) + g102 = Group(number=102, major='B.S./M.S. in Computer Science', dept=d2) + g103 = Group(number=103, major='B.S. in Applied Mathematics and Statistics', dept=d2) + g104 = Group(number=104, major='B.S./M.S. in Pure Mathematics', dept=d2) + g105 = Group(number=105, major='B.E in Electronics', dept=d3) + g106 = Group(number=106, major='B.S./M.S. in Nuclear Engineering', dept=d3) + + Student(id=1, name='John Smith', dob=date(1991, 3, 20), tel='123-456', gpa=3, group=g101, phd=True, + courses=[c1, c2, c4, c6]) + Student(id=2, name='Matthew Reed', dob=date(1990, 11, 26), gpa=3.5, group=g101, phd=True, + courses=[c1, c3, c4, c5]) + Student(id=3, name='Chuan Qin', dob=date(1989, 2, 5), gpa=4, group=g101, + courses=[c3, c5, c6]) + Student(id=4, name='Rebecca Lawson', dob=date(1990, 4, 18), tel='234-567', gpa=3.3, group=g102, + courses=[c1, c4, c5, c6]) + Student(id=5, name='Maria Ionescu', dob=date(1991, 4, 23), gpa=3.9, group=g102, + courses=[c1, c2, c4, c6]) + Student(id=6, name='Oliver Blakey', dob=date(1990, 9, 8), gpa=3.1, group=g102, + courses=[c1, c2, c5]) + Student(id=7, name='Jing Xia', dob=date(1988, 12, 30), gpa=3.2, group=g102, + courses=[c1, c3, c5, c6]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_declarative_strings.py b/pony/orm/tests/test_declarative_strings.py index 2706e82c9..f82b6ff18 100644 --- a/pony/orm/tests/test_declarative_strings.py +++ b/pony/orm/tests/test_declarative_strings.py @@ -4,24 +4,31 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database 
-db = Database('sqlite', ':memory:') +db = Database() class Student(db.Entity): - name = Required(unicode, autostrip=False) - foo = Optional(unicode) - bar = Optional(unicode) - -db.generate_mapping(create_tables=True) - -with db_session: - Student(id=1, name="Jon", foo='Abcdef', bar='b%d') - Student(id=2, name=" Bob ", foo='Ab%def', bar='b%d') - Student(id=3, name=" Beth ", foo='Ab_def', bar='b%d') - Student(id=4, name="Jonathan") - Student(id=5, name="Pete") + name = Required(str) + unstripped = Required(str, autostrip=False) + foo = Optional(str) + bar = Optional(str) class TestStringMethods(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Student(id=1, name="Ann", unstripped="Ann", foo='Abcdef', bar='b%d') + Student(id=2, name="Bob", unstripped=" Bob ", foo='Ab%def', bar='b%d') + Student(id=3, name="Beth", unstripped=" Beth ", foo='Ab_def', bar='b%d') + Student(id=4, name="Jonathan", unstripped="\nJonathan\n") + Student(id=5, name="Pete", unstripped="\n Pete\n ") + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -30,106 +37,129 @@ def tearDown(self): rollback() db_session.__exit__() - def test_nonzero(self): - result = set(select(s for s in Student if s.foo)) - self.assertEqual(result, {Student[1], Student[2], Student[3]}) - - def test_add(self): - name = 'Jonny' - result = set(select(s for s in Student if s.name + "ny" == name)) + def test_getitem_01(self): + result = set(select(s for s in Student if s.name[:] == 'Ann')) self.assertEqual(result, {Student[1]}) + def test_getitem_1(self): + result = set(select(s for s in Student if s.name[1] == 'o')) + self.assertEqual(result, {Student[2], Student[4]}) + + def test_getitem_2(self): + x = 1 + result = set(select(s for s in Student if s.name[x] == 'o')) + self.assertEqual(result, {Student[2], Student[4]}) + + def test_getitem_3(self): + result = set(select(s for s in Student if s.name[-1] == 
'n')) + self.assertEqual(result, {Student[1], Student[4]}) + + def test_getitem_4(self): + x = -1 + result = set(select(s for s in Student if s.name[x] == 'n')) + self.assertEqual(result, {Student[1], Student[4]}) + + def test_getitem_5(self): + result = set(select(s for s in Student if s.name[-2] == 't')) + self.assertEqual(result, {Student[3], Student[5]}) + + @sql_debugging + def test_getitem_6(self): + x = -2 + select((s.name, s.name[x]) for s in Student).show() + result = set(select(s for s in Student if s.name[x] == 't')) + self.assertEqual(result, {Student[3], Student[5]}) + def test_slice_1(self): result = set(select(s for s in Student if s.name[0:3] == "Jon")) - self.assertEqual(result, {Student[1], Student[4]}) + self.assertEqual(result, {Student[4]}) def test_slice_2(self): result = set(select(s for s in Student if s.name[:3] == "Jon")) - self.assertEqual(result, {Student[1], Student[4]}) + self.assertEqual(result, {Student[4]}) def test_slice_3(self): x = 3 result = set(select(s for s in Student if s.name[:x] == "Jon")) - self.assertEqual(result, {Student[1], Student[4]}) + self.assertEqual(result, {Student[4]}) def test_slice_4(self): x = 3 result = set(select(s for s in Student if s.name[0:x] == "Jon")) - self.assertEqual(result, {Student[1], Student[4]}) + self.assertEqual(result, {Student[4]}) def test_slice_5(self): - result = set(select(s for s in Student if s.name[0:10] == "Jon")) + result = set(select(s for s in Student if s.name[0:10] == "Ann")) self.assertEqual(result, {Student[1]}) def test_slice_6(self): - result = set(select(s for s in Student if s.name[0:] == "Jon")) + result = set(select(s for s in Student if s.name[0:] == "Ann")) self.assertEqual(result, {Student[1]}) def test_slice_7(self): - result = set(select(s for s in Student if s.name[:] == "Jon")) + result = set(select(s for s in Student if s.name[:] == "Ann")) self.assertEqual(result, {Student[1]}) def test_slice_8(self): - result = set(select(s for s in Student if s.name[1:] == 
"on")) + result = set(select(s for s in Student if s.name[1:] == "nn")) self.assertEqual(result, {Student[1]}) def test_slice_9(self): x = 1 - result = set(select(s for s in Student if s.name[x:] == "on")) + result = set(select(s for s in Student if s.name[x:] == "nn")) self.assertEqual(result, {Student[1]}) def test_slice_10(self): x = 0 - result = set(select(s for s in Student if s.name[x:3] == "Jon")) - self.assertEqual(result, {Student[1], Student[4]}) + result = set(select(s for s in Student if s.name[x:3] == "Ann")) + self.assertEqual(result, {Student[1]}) def test_slice_11(self): + result = set(select(s for s in Student if s.name[1:3] == "et")) + self.assertEqual(result, {Student[3], Student[5]}) + + def test_slice_12(self): x = 1 y = 3 - result = set(select(s for s in Student if s.name[x:y] == "on")) - self.assertEqual(result, {Student[1], Student[4]}) + result = set(select(s for s in Student if s.name[x:y] == "et")) + self.assertEqual(result, {Student[3], Student[5]}) - def test_slice_12(self): + def test_slice_13(self): x = 10 y = 20 result = set(select(s for s in Student if s.name[x:y] == '')) self.assertEqual(result, {Student[1], Student[2], Student[3], Student[4], Student[5]}) - def test_getitem_1(self): - result = set(select(s for s in Student if s.name[1] == 'o')) - self.assertEqual(result, {Student[1], Student[4]}) - - def test_getitem_2(self): - x = 1 - result = set(select(s for s in Student if s.name[x] == 'o')) - self.assertEqual(result, {Student[1], Student[4]}) + def test_slice_14(self): + result = set(select(s for s in Student if s.name[-2:] == "nn")) + self.assertEqual(result, {Student[1]}) - def test_getitem_3(self): - result = set(select(s for s in Student if s.name[-1] == 'n')) - self.assertEqual(result, {Student[1], Student[4]}) + def test_nonzero(self): + result = set(select(s for s in Student if s.foo)) + self.assertEqual(result, {Student[1], Student[2], Student[3]}) - def test_getitem_4(self): - x = -1 - result = set(select(s for s in 
Student if s.name[x] == 'n')) - self.assertEqual(result, {Student[1], Student[4]}) + def test_add(self): + name = 'Bethy' + result = set(select(s for s in Student if s.name + "y" == name)) + self.assertEqual(result, {Student[3]}) def test_contains_1(self): result = set(select(s for s in Student if 'o' in s.name)) - self.assertEqual(result, {Student[1], Student[2], Student[4]}) + self.assertEqual(result, {Student[2], Student[4]}) def test_contains_2(self): - result = set(select(s for s in Student if 'on' in s.name)) - self.assertEqual(result, {Student[1], Student[4]}) + result = set(select(s for s in Student if 'an' in s.name)) + self.assertEqual(result, {Student[4]}) def test_contains_3(self): - x = 'on' + x = 'an' result = set(select(s for s in Student if x in s.name)) - self.assertEqual(result, {Student[1], Student[4]}) + self.assertEqual(result, {Student[4]}) def test_contains_4(self): - x = 'on' + x = 'an' result = set(select(s for s in Student if x not in s.name)) - self.assertEqual(result, {Student[2], Student[3], Student[5]}) + self.assertEqual(result, {Student[1], Student[2], Student[3], Student[5]}) def test_contains_5(self): result = set(select(s for s in Student if '%' in s.foo)) @@ -158,20 +188,20 @@ def test_contains_10(self): self.assertEqual(result, {Student[2], Student[4], Student[5]}) def test_startswith_1(self): - students = set(select(s for s in Student if s.name.startswith('J'))) - self.assertEqual(students, {Student[1], Student[4]}) + students = set(select(s for s in Student if s.name.startswith('B'))) + self.assertEqual(students, {Student[2], Student[3]}) def test_startswith_2(self): - students = set(select(s for s in Student if not s.name.startswith('J'))) - self.assertEqual(students, {Student[2], Student[3], Student[5]}) + students = set(select(s for s in Student if not s.name.startswith('B'))) + self.assertEqual(students, {Student[1], Student[4], Student[5]}) def test_startswith_3(self): - students = set(select(s for s in Student if not not 
s.name.startswith('J'))) - self.assertEqual(students, {Student[1], Student[4]}) + students = set(select(s for s in Student if not not s.name.startswith('B'))) + self.assertEqual(students, {Student[2], Student[3]}) def test_startswith_4(self): - students = set(select(s for s in Student if not not not s.name.startswith('J'))) - self.assertEqual(students, {Student[2], Student[3], Student[5]}) + students = set(select(s for s in Student if not not not s.name.startswith('B'))) + self.assertEqual(students, {Student[1], Student[4], Student[5]}) def test_startswith_5(self): x = "Pe" @@ -191,8 +221,13 @@ def test_strip_1(self): students = select(s for s in Student if s.name.strip() == 'Beth')[:] self.assertEqual(students, [Student[3]]) - def test_rstrip(self): - students = select(s for s in Student if s.name.rstrip('n') == 'Jo')[:] + def test_rstrip_1(self): + students = select(s for s in Student if s.name.rstrip('n') == 'A')[:] + self.assertEqual(students, [Student[1]]) + + def test_rstrip_2(self): + x = 'n' + students = select(s for s in Student if s.name.rstrip(x) == 'A')[:] self.assertEqual(students, [Student[1]]) def test_lstrip(self): @@ -200,11 +235,11 @@ def test_lstrip(self): self.assertEqual(students, [Student[5]]) def test_upper(self): - result = select(s for s in Student if s.name.upper() == "JON")[:] + result = select(s for s in Student if s.name.upper() == "ANN")[:] self.assertEqual(result, [Student[1]]) def test_lower(self): - result = select(s for s in Student if s.name.lower() == "jon")[:] + result = select(s for s in Student if s.name.lower() == "ann")[:] self.assertEqual(result, [Student[1]]) if __name__ == "__main__": diff --git a/pony/orm/tests/test_deduplication.py b/pony/orm/tests/test_deduplication.py index 6842deb26..146a33181 100644 --- a/pony/orm/tests/test_deduplication.py +++ b/pony/orm/tests/test_deduplication.py @@ -1,39 +1,43 @@ from pony.py23compat import StringIO +from pony.orm import * +from pony.orm.tests import setup_database, 
teardown_database import unittest -from pony import orm - - -db = orm.Database('sqlite', ':memory:') +db = Database() class A(db.Entity): - id = orm.PrimaryKey(int) - x = orm.Required(bool) - y = orm.Required(float) - -db.generate_mapping(create_tables=True) - -with orm.db_session: - a1 = A(id=1, x=False, y=3.0) - a2 = A(id=2, x=True, y=4.0) - a3 = A(id=3, x=False, y=1.0) + id = PrimaryKey(int) + x = Required(bool) + y = Required(float) class TestDeduplication(unittest.TestCase): - @orm.db_session + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + a1 = A(id=1, x=False, y=3.0) + a2 = A(id=2, x=True, y=4.0) + a3 = A(id=3, x=False, y=1.0) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + + @db_session def test_1(self): a2 = A.get(id=2) a1 = A.get(id=1) self.assertIs(a1.id, 1) - @orm.db_session + @db_session def test_2(self): a3 = A.get(id=3) a1 = A.get(id=1) self.assertIs(a1.id, 1) - @orm.db_session + @db_session def test_3(self): q = A.select().order_by(-1) stream = StringIO() diff --git a/pony/orm/tests/test_diagram.py b/pony/orm/tests/test_diagram.py index 3b525aed5..a812b440f 100644 --- a/pony/orm/tests/test_diagram.py +++ b/pony/orm/tests/test_diagram.py @@ -5,12 +5,13 @@ from pony.orm.core import * from pony.orm.core import Entity from pony.orm.tests.testutils import * +from pony.orm.tests import db_params -class TestDiag(unittest.TestCase): +class TestDiag(unittest.TestCase): @raises_exception(ERDiagramError, 'Entity Entity1 already exists') def test_entity_duplicate(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) class Entity1(db.Entity): @@ -19,66 +20,75 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, 'Interrelated entities must belong to same database.' 
' Entities Entity2 and Entity1 belongs to different databases') def test_diagram1(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') - db = Database('sqlite', ':memory:') + db.bind(**db_params) + db = Database() class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) + db.bind(**db_params) db.generate_mapping() @raises_exception(ERDiagramError, 'Entity definition Entity2 was not found') def test_diagram2(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') + db.bind(**db_params) db.generate_mapping() @raises_exception(TypeError, 'Entity1._table_ property must be a string. Got: 123') def test_diagram3(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): _table_ = 123 id = PrimaryKey(int) + db.bind(**db_params) db.generate_mapping() def test_diagram4(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2', table='Table1') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, table='Table1') + db.bind(**db_params) db.generate_mapping(create_tables=True) + db.drop_all_tables(with_all_data=True) def test_diagram5(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1) + db.bind(**db_params) db.generate_mapping(create_tables=True) + db.drop_all_tables(with_all_data=True) @raises_exception(MappingError, "Parameter 'table' for Entity1.attr1 and Entity2.attr2 do not match") def test_diagram6(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2', table='Table1') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, 
table='Table2') + db.bind(**db_params) db.generate_mapping() @raises_exception(MappingError, 'Table name "Table1" is already in use') def test_diagram7(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): _table_ = 'Table1' id = PrimaryKey(int) @@ -86,24 +96,35 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, table='Table1') + db.bind(**db_params) db.generate_mapping() def test_diagram8(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Set('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1) + db.bind(**db_params) db.generate_mapping(create_tables=True) - m2m_table = db.schema.tables['Entity1_Entity2'] - col_names = {col.name for col in m2m_table.column_list} - self.assertEqual(col_names, {'entity1', 'entity2'}) - self.assertEqual(Entity1.attr1.get_m2m_columns(), ['entity1']) + if pony.__version__ >= '0.9': + m2m_table = db.schema.tables['entity1_attr1'] + col_names = set(m2m_table.columns) + self.assertEqual(col_names, {'entity1_id', 'entity2_id'}) + m2m_columns = [c.name for c in Entity1.attr1.meta.m2m_columns] + self.assertEqual(m2m_columns, ['entity1_id']) + else: + table_name = 'Entity1_Entity2' if db.provider.dialect == 'SQLite' else 'entity1_entity2' + m2m_table = db.schema.tables[table_name] + col_names = {col.name for col in m2m_table.column_list} + self.assertEqual(col_names, {'entity1', 'entity2'}) + self.assertEqual(Entity1.attr1.get_m2m_columns(), ['entity1']) + db.drop_all_tables(with_all_data=True) def test_diagram9(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -112,13 +133,21 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1) + db.bind(**db_params) db.generate_mapping(create_tables=True) - m2m_table = db.schema.tables['Entity1_Entity2'] - col_names 
= {col.name for col in m2m_table.column_list} - self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2'}) + if pony.__version__ >= '0.9': + m2m_table = db.schema.tables['entity1_attr1'] + col_names = {col for col in m2m_table.columns} + self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2_id'}) + else: + table_name = 'Entity1_Entity2' if db.provider.dialect == 'SQLite' else 'entity1_entity2' + m2m_table = db.schema.tables[table_name] + col_names = set([col.name for col in m2m_table.column_list]) + self.assertEqual(col_names, {'entity1_a', 'entity1_b', 'entity2'}) + db.drop_all_tables(with_all_data=True) def test_diagram10(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -127,11 +156,13 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, columns=['x', 'y']) + db.bind(**db_params) db.generate_mapping(create_tables=True) + db.drop_all_tables(with_all_data=True) @raises_exception(MappingError, 'Invalid number of columns for Entity2.attr2') def test_diagram11(self): - db = Database('sqlite', ':memory:') + db = Database() class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -140,6 +171,7 @@ class Entity1(db.Entity): class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Set(Entity1, columns=['x']) + db.bind(**db_params) db.generate_mapping() @raises_exception(ERDiagramError, 'Base Entity does not belong to any database') @@ -149,9 +181,10 @@ class Test(Entity): @raises_exception(ERDiagramError, 'Entity class name should start with a capital letter. 
Got: entity1') def test_diagram13(self): - db = Database('sqlite', ':memory:') + db = Database() class entity1(db.Entity): a = Required(int) + db.bind(**db_params) db.generate_mapping() if __name__ == '__main__': diff --git a/pony/orm/tests/test_diagram_attribute.py b/pony/orm/tests/test_diagram_attribute.py index fccda95ee..f7c29582b 100644 --- a/pony/orm/tests/test_diagram_attribute.py +++ b/pony/orm/tests/test_diagram_attribute.py @@ -6,44 +6,51 @@ from pony.orm.core import * from pony.orm.core import Attribute from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, only_for, setup_database, teardown_database + class TestAttribute(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + + def tearDown(self): + teardown_database(self.db) @raises_exception(TypeError, "Attribute Entity1.id has unknown option 'another_option'") def test_attribute1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, another_option=3) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, 'Cannot link attribute Entity1.b to abstract Entity class. 
Use specific Entity subclass instead') def test_attribute2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) b = Required(db.Entity) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(TypeError, 'Default value for required attribute Entity1.b cannot be None') def test_attribute3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) b = Required(int, default=None) def test_attribute4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2', reverse='attr2') class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.attr1.reverse, Entity2.attr2) def test_attribute5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') @@ -54,35 +61,35 @@ class Entity2(db.Entity): @raises_exception(TypeError, "Value of 'reverse' option must be name of reverse attribute). Got: 123") def test_attribute6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2', reverse=123) @raises_exception(TypeError, "Reverse option cannot be set for this type: %r" % str) def test_attribute7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required(str, reverse='attr1') @raises_exception(TypeError, "'Attribute' is abstract type") def test_attribute8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Attribute(str) @raises_exception(ERDiagramError, "Attribute name cannot both start and end with underscore. 
Got: _attr1_") def test_attribute9(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) _attr1_ = Required(str) @raises_exception(ERDiagramError, "Duplicate use of attribute Entity1.attr1 in entity Entity2") def test_attribute10(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required(str) @@ -92,7 +99,7 @@ class Entity2(db.Entity): @raises_exception(ERDiagramError, "Invalid use of attribute Entity1.a in entity Entity2") def test_attribute11(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(str) class Entity2(db.Entity): @@ -102,33 +109,33 @@ class Entity2(db.Entity): @raises_exception(ERDiagramError, "Cannot create default primary key attribute for Entity1 because name 'id' is already in use." " Please create a PrimaryKey attribute for entity Entity1 or rename the 'id' attribute") def test_attribute12(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = Optional(str) @raises_exception(ERDiagramError, "Reverse attribute for Entity1.attr1 not found") def test_attribute13(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required('Entity2') class Entity2(db.Entity): id = PrimaryKey(int) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Reverse attribute Entity1.attr1 not found") def test_attribute14(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required(Entity1, reverse='attr1') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Inconsistent reverse attributes Entity3.attr3 and Entity2.attr2") def test_attribute15(self): - db = Database('sqlite', ':memory:') + db = self.db class 
Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') @@ -138,11 +145,11 @@ class Entity2(db.Entity): class Entity3(db.Entity): id = PrimaryKey(int) attr3 = Required(Entity2, reverse='attr2') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Inconsistent reverse attributes Entity3.attr3 and Entity2.attr2") def test_attribute16(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') @@ -152,21 +159,21 @@ class Entity2(db.Entity): class Entity3(db.Entity): id = PrimaryKey(int) attr3 = Required(Entity2, reverse=Entity2.attr2) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, 'Reverse attribute for Entity2.attr2 not found') def test_attribute18(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Required('Entity1') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Ambiguous reverse attribute for Entity1.a. Use the 'reverse' parameter for pointing to right attribute") def test_attribute19(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2') @@ -175,11 +182,11 @@ class Entity2(db.Entity): id = PrimaryKey(int) c = Set(Entity1) d = Set(Entity1) - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, "Ambiguous reverse attribute for Entity1.c. 
Use the 'reverse' parameter for pointing to right attribute") def test_attribute20(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) c = Set('Entity2') @@ -187,10 +194,10 @@ class Entity2(db.Entity): id = PrimaryKey(int) a = Required(Entity1, reverse='c') b = Optional(Entity1, reverse='c') - db.generate_mapping() + db.generate_mapping(check_tables=False) def test_attribute21(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2', reverse='c') @@ -201,7 +208,7 @@ class Entity2(db.Entity): d = Set(Entity1) def test_attribute22(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required('Entity2', reverse='c') @@ -213,79 +220,81 @@ class Entity2(db.Entity): @raises_exception(ERDiagramError, 'Inconsistent reverse attributes Entity1.a and Entity2.b') def test_attribute23(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required('Entity2', reverse='b') class Entity2(db.Entity): b = Optional('Entity3') class Entity3(db.Entity): c = Required('Entity2') - db.generate_mapping() + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, 'Inconsistent reverse attributes Entity1.a and Entity2.c') def test_attribute23(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required('Entity2', reverse='c') b = Required('Entity2', reverse='d') class Entity2(db.Entity): c = Optional('Entity1', reverse='b') d = Optional('Entity1', reverse='a') - db.generate_mapping() + db.generate_mapping(check_tables=False) def test_attribute24(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(str, auto=True) db.generate_mapping(create_tables=True) - self.assertTrue('AUTOINCREMENT' not in db.schema.tables['Entity1'].get_create_command()) + table_name = 'Entity1' if 
db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'entity1' + self.assertTrue('AUTOINCREMENT' not in db.schema.tables[table_name].get_create_command()) @raises_exception(TypeError, "Parameters 'column' and 'columns' cannot be specified simultaneously") def test_columns1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a', columns=['b', 'c']) class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional(Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, column='a') + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.id.columns, ['a']) def test_columns3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, columns=['a']) self.assertEqual(Entity1.id.column, 'a') @raises_exception(MappingError, "Too many columns were specified for Entity1.id") def test_columns5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, columns=['a', 'b']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'columns' must be a list. Got: %r'" % {'a'}) def test_columns6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, columns={'a'}) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'column' must be a string. 
Got: 4") def test_columns7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int, column=4) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -293,12 +302,13 @@ class Entity1(db.Entity): PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=['x', 'y']) + db.generate_mapping(check_tables=False) self.assertEqual(Entity2.attr2.column, None) self.assertEqual(Entity2.attr2.columns, ['x', 'y']) @raises_exception(MappingError, 'Invalid number of columns specified for Entity2.attr2') def test_columns9(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -306,11 +316,11 @@ class Entity1(db.Entity): PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, columns=['x', 'y', 'z']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(MappingError, 'Invalid number of columns specified for Entity2.attr2') def test_columns10(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -318,11 +328,11 @@ class Entity1(db.Entity): PrimaryKey(a, b) class Entity2(db.Entity): attr2 = Required(Entity1, column='x') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Items of parameter 'columns' must be strings. 
Got: [1, 2]") def test_columns11(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -332,97 +342,97 @@ class Entity2(db.Entity): attr2 = Required(Entity1, columns=[1, 2]) def test_columns12(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column='column2', reverse_columns=['column2']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameters 'reverse_column' and 'reverse_columns' cannot be specified simultaneously") def test_columns13(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column='column2', reverse_columns=['column3']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'reverse_column' must be a string. Got: 5") def test_columns14(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_column=5) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'reverse_columns' must be a list. Got: 'column3'") def test_columns15(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns='column3') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'reverse_columns' must be a list of strings. 
Got: [5]") def test_columns16(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns=[5]) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns17(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', column='column1', reverse_columns=['column2']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns18(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table='T1') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Parameter 'table' must be a string. Got: 5") def test_columns19(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table=5) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "Each part of table name must be a string. 
Got: 1") def test_columns20(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): attr1 = Set('Entity1', reverse='attr1', table=[1, 'T1']) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_columns_21(self): - db = Database('sqlite', ':memory:') + db = self.db class Stat(db.Entity): webinarshow = Optional('WebinarShow') class WebinarShow(db.Entity): stats = Required('Stat') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Stat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') def test_columns_22(self): - db = Database('sqlite', ':memory:') + db = self.db class ZStat(db.Entity): webinarshow = Optional('WebinarShow') class WebinarShow(db.Entity): stats = Required('ZStat') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(ZStat.webinarshow.column, None) self.assertEqual(WebinarShow.stats.column, 'stats') def test_nullable1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(unicode, unique=True) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.a.nullable, True) def test_nullable2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(unicode, unique=True) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: Entity1() commit() @@ -430,23 +440,23 @@ class Entity1(db.Entity): commit() def test_lambda_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(lambda: db.Entity2) class Entity2(db.Entity): b = Set(lambda: db.Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertEqual(Entity1.a.py_type, Entity2) self.assertEqual(Entity2.b.py_type, Entity1) @raises_exception(TypeError, "Invalid type of 
attribute Entity1.a: expected entity class, got 'Entity2'") def test_lambda_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(lambda: 'Entity2') class Entity2(db.Entity): b = Set(lambda: db.Entity1) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(ERDiagramError, 'Interrelated entities must belong to same database. ' 'Entities Entity1 and Entity2 belongs to different databases') @@ -457,47 +467,47 @@ class Entity1(db1.Entity): db2 = Database('sqlite', ':memory:') class Entity2(db2.Entity): b = Set(lambda: db1.Entity1) - db1.generate_mapping(create_tables=True) + db1.generate_mapping(check_tables=False) @raises_exception(ValueError, 'Check for attribute Entity1.a failed. Value: 1') def test_py_check_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int, py_check=lambda val: val > 5 and val < 10) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=1) def test_py_check_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int, py_check=lambda val: val > 5 and val < 10) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=7) def test_py_check_3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=None) @raises_exception(ValueError, 'Check for attribute Entity1.a failed. 
Value: datetime.date(1999, 1, 1)') def test_py_check_4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=date(1999, 1, 1)) def test_py_check_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(date, py_check=lambda val: val.year >= 2000) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=date(2010, 1, 1)) @@ -505,10 +515,10 @@ class Entity1(db.Entity): def test_py_check_6(self): def positive_number(val): if val <= 0: raise ValueError('Should be positive number') - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(int, py_check=positive_number) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=-1) @@ -516,27 +526,27 @@ def test_py_check_7(self): def positive_number(val): if val <= 0: raise ValueError('Should be positive number') return True - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Optional(int, py_check=positive_number) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a=1) @raises_exception(NotImplementedError, "'py_check' parameter is not supported for collection attributes") def test_py_check_8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required('Entity2') class Entity2(db.Entity): a = Set('Entity1', py_check=lambda val: True) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_py_check_truncate(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(str, py_check=lambda val: False) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: try: obj = 
Entity1(a='1234567890' * 1000) @@ -550,88 +560,93 @@ class Entity1(db.Entity): @raises_exception(ValueError, 'Value for attribute Entity1.a is too long. Max length is 10, value length is 10000') def test_str_max_len(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(str, 10) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: obj = Entity1(a='1234567890' * 1000) + @only_for('sqlite') def test_foreign_key_sql_type_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SOME_TYPE') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'SOME_TYPE') + @only_for('sqlite') def test_foreign_key_sql_type_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SOME_TYPE') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo, sql_type='ANOTHER_TYPE') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'ANOTHER_TYPE') + @only_for('sqlite') def test_foreign_key_sql_type_3(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='SERIAL') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo, sql_type='ANOTHER_TYPE') - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type self.assertEqual(sql_type, 'ANOTHER_TYPE') def test_foreign_key_sql_type_4(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, 
sql_type='SERIAL') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type - self.assertEqual(sql_type, 'INTEGER') + required_type = 'INT8' if db.provider_name == 'cockroach' else 'INTEGER' + self.assertEqual(required_type, sql_type) def test_foreign_key_sql_type_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(unicode, sql_type='serial') bars = Set('Bar') class Bar(db.Entity): foo = Required(Foo) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) table = db.schema.tables.get(Bar._table_) sql_type = table.column_list[1].sql_type - self.assertEqual(sql_type, 'integer') + required_type = 'int8' if db.provider_name == 'cockroach' else 'integer' + self.assertEqual(required_type, sql_type) def test_self_referenced_m2m_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Node(db.Entity): id = PrimaryKey(int) prev_nodes = Set("Node") next_nodes = Set("Node") - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) def test_implicit_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): name = Required(str) bar = Required("Bar") @@ -639,7 +654,7 @@ class Bar(db.Entity): id = PrimaryKey(int) name = Optional(str) foos = Set("Foo") - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertTrue(Foo.id.is_implicit) self.assertFalse(Foo.name.is_implicit) @@ -650,12 +665,12 @@ class Bar(db.Entity): self.assertFalse(Bar.foos.is_implicit) def test_implicit_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): x = Required(str) class Bar(Foo): y = Required(str) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) self.assertTrue(Foo.id.is_implicit) 
self.assertTrue(Foo.classtype.is_implicit) @@ -668,16 +683,16 @@ class Bar(Foo): @raises_exception(TypeError, 'Attribute Foo.x has invalid type NoneType') def test_none_type(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): x = Required(type(None)) - db.generate_mapping(create_tables=True) + db.generate_mapping(check_tables=False) @raises_exception(TypeError, "'sql_default' option value cannot be empty string, " "because it should be valid SQL literal or expression. " "Try to use \"''\", or just specify default='' instead.") def test_none_type(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): x = Required(str, sql_default='') diff --git a/pony/orm/tests/test_diagram_keys.py b/pony/orm/tests/test_diagram_keys.py index 510a497bb..653656053 100644 --- a/pony/orm/tests/test_diagram_keys.py +++ b/pony/orm/tests/test_diagram_keys.py @@ -4,11 +4,15 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, setup_database, teardown_database + class TestKeys(unittest.TestCase): + def tearDown(self): + teardown_database(self.db) def test_keys1(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Required(str) @@ -20,7 +24,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._composite_keys_, []) def test_keys2(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(str) @@ -34,14 +38,14 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, 'Only one primary key can be defined in each entity class') def test_keys3(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = PrimaryKey(int) @raises_exception(ERDiagramError, 'Only one primary key can be defined in each entity class') def test_keys4(self): - db = 
Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Required(int) @@ -49,7 +53,7 @@ class Entity1(db.Entity): PrimaryKey(b, c) def test_unique1(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Required(int, unique=True) @@ -58,7 +62,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._composite_keys_, []) def test_unique2(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Optional(int, unique=True) @@ -67,7 +71,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._composite_keys_, []) def test_unique2_1(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Optional(int) @@ -79,28 +83,28 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'composite_key() must receive at least two attributes as arguments') def test_unique3(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) composite_key() @raises_exception(TypeError, 'composite_key() arguments must be attributes. Got: 123') def test_unique4(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) composite_key(123, 456) @raises_exception(TypeError, "composite_key() arguments must be attributes. 
Got: %r" % int) def test_unique5(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) composite_key(int, a) @raises_exception(TypeError, 'Set attribute Entity1.b cannot be part of unique index') def test_unique6(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Set('Entity2') @@ -108,14 +112,14 @@ class Entity1(db.Entity): @raises_exception(TypeError, "'unique' option cannot be set for attribute Entity1.b because it is collection") def test_unique7(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Set('Entity2', unique=True) @raises_exception(TypeError, 'Optional attribute Entity1.b cannot be part of primary key') def test_unique8(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Optional(int) @@ -123,13 +127,13 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'PrimaryKey attribute Entity1.a cannot be of type float') def test_float_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(float) @raises_exception(TypeError, 'Attribute Entity1.b of type float cannot be part of primary key') def test_float_composite_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(float) @@ -137,7 +141,7 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'Attribute Entity1.b of type float cannot be part of unique index') def test_float_composite_key(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(float) @@ -145,33 +149,33 @@ class Entity1(db.Entity): @raises_exception(TypeError, 'Unique attribute 
Entity1.a cannot be of type float') def test_float_unique(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(float, unique=True) @raises_exception(TypeError, 'PrimaryKey attribute Entity1.a cannot be volatile') def test_volatile_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int, volatile=True) @raises_exception(TypeError, 'Set attribute Entity1.b cannot be volatile') def test_volatile_set(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = PrimaryKey(int) b = Set('Entity2', volatile=True) @raises_exception(TypeError, 'Volatile attribute Entity1.b cannot be part of primary key') def test_volatile_composite_pk(self): - db = Database('sqlite', ':memory:') + db = self.db = Database(**db_params) class Entity1(db.Entity): a = Required(int) b = Required(int, volatile=True) PrimaryKey(a, b) def test_composite_key_update(self): - db = Database('sqlite', ':memory:') + db = self.db = Database() class Entity1(db.Entity): s = Set('Entity3') class Entity2(db.Entity): @@ -180,7 +184,8 @@ class Entity3(db.Entity): a = Required(Entity1) b = Required(Entity2) composite_key(a, b) - db.generate_mapping(create_tables=True) + setup_database(db) + with db_session: x = Entity1(id=1) y = Entity2(id=1) diff --git a/pony/orm/tests/test_distinct.py b/pony/orm/tests/test_distinct.py index 5f4ce1eb5..a53f67b4c 100644 --- a/pony/orm/tests/test_distinct.py +++ b/pony/orm/tests/test_distinct.py @@ -4,8 +4,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Department(db.Entity): number = PrimaryKey(int) @@ -29,25 +30,30 @@ class Course(db.Entity): PrimaryKey(name, semester) students = Set('Student') -db.generate_mapping(create_tables=True) - 
-with db_session: - d1 = Department(number=1) - d2 = Department(number=2) - g1 = Group(id=1, dept=d1) - g2 = Group(id=2, dept=d2) - s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) - s2 = Student(id=2, name='S2', age=21, group=g1, scholarship=100) - s3 = Student(id=3, name='S3', age=23, group=g1, scholarship=200) - s4 = Student(id=4, name='S4', age=21, group=g1, scholarship=100) - s5 = Student(id=5, name='S5', age=23, group=g2, scholarship=0) - s6 = Student(id=6, name='S6', age=23, group=g2, scholarship=200) - c1 = Course(name='C1', semester=1, students=[s1, s2, s3]) - c2 = Course(name='C2', semester=1, students=[s2, s3, s5, s6]) - c3 = Course(name='C3', semester=2, students=[s4, s5, s6]) - class TestDistinct(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + d1 = Department(number=1) + d2 = Department(number=2) + g1 = Group(id=1, dept=d1) + g2 = Group(id=2, dept=d2) + s1 = Student(id=1, name='S1', age=20, group=g1, scholarship=0) + s2 = Student(id=2, name='S2', age=21, group=g1, scholarship=100) + s3 = Student(id=3, name='S3', age=23, group=g1, scholarship=200) + s4 = Student(id=4, name='S4', age=21, group=g1, scholarship=100) + s5 = Student(id=5, name='S5', age=23, group=g2, scholarship=0) + s6 = Student(id=6, name='S6', age=23, group=g2, scholarship=200) + c1 = Course(name='C1', semester=1, students=[s1, s2, s3]) + c2 = Course(name='C2', semester=1, students=[s2, s3, s5, s6]) + c3 = Course(name='C3', semester=2, students=[s4, s5, s6]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): db_session.__enter__() diff --git a/pony/orm/tests/test_entity_init.py b/pony/orm/tests/test_entity_init.py index 7f1a2b896..37d62ec1b 100644 --- a/pony/orm/tests/test_entity_init.py +++ b/pony/orm/tests/test_entity_init.py @@ -6,23 +6,32 @@ from pony.orm.tests.testutils import raises_exception from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -class 
TestCustomInit(unittest.TestCase): - def test1(self): - db = Database('sqlite', ':memory:') +db = Database() - class User(db.Entity): - name = Required(str) - password = Required(str) - created_at = Required(datetime) - def __init__(self, name, password): - password = md5(password.encode('utf8')).hexdigest() - super(User, self).__init__(name=name, password=password, created_at=datetime.now()) - self.uppercase_name = name.upper() +class User(db.Entity): + name = Required(str) + password = Required(str) + created_at = Required(datetime) - db.generate_mapping(create_tables=True) + def __init__(self, name, password): + password = md5(password.encode('utf8')).hexdigest() + super(User, self).__init__(name=name, password=password, created_at=datetime.now()) + self.uppercase_name = name.upper() + +class TestCustomInit(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(self): + teardown_database(db) + + def test1(self): with db_session: u1 = User('John', '123') u2 = User('Mike', '456') diff --git a/pony/orm/tests/test_entity_instances.py b/pony/orm/tests/test_entity_instances.py deleted file mode 100644 index bf1e302a0..000000000 --- a/pony/orm/tests/test_entity_instances.py +++ /dev/null @@ -1,103 +0,0 @@ -import unittest - -from pony import orm -from pony.orm.core import * -from pony.orm.tests.testutils import raises_exception - -db = Database('sqlite', ':memory:') - -class Person(db.Entity): - id = orm.PrimaryKey(int, auto=True) - name = orm.Required(str, 40) - lastName = orm.Required(str, max_len=40, unique=True) - age = orm.Optional(int) - groupName = orm.Optional('Group') - chiefOfGroup = orm.Optional('Group') - -class Group(db.Entity): - name = orm.Required(str) - persons = orm.Set(Person) - chief = orm.Optional(Person, reverse='chiefOfGroup') - -db.generate_mapping(create_tables=True) - -class TestEntityInstances(unittest.TestCase): - - def setUp(self): - rollback() - db_session.__enter__() - - def 
tearDown(self): - rollback() - db_session.__exit__() - - def test_create_instance(self): - with orm.db_session: - Person(id=1, name='Philip', lastName='Croissan') - Person(id=2, name='Philip', lastName='Parlee', age=40) - Person(id=3, name='Philip', lastName='Illinois', age=50) - commit() - - def test_getObjectByPK(self): - self.assertEqual(Person[1].lastName, "Croissan") - - @raises_exception(ObjectNotFound , "Person[666]") - def test_getObjectByPKexception(self): - p = Person[666] - - def test_getObjectByGet(self): - p = Person.get(age=40) - self.assertEqual(p.lastName, "Parlee") - - def test_getObjectByGetNone(self): - self.assertIsNone(Person.get(age=41)) - - @raises_exception(MultipleObjectsFoundError , 'Multiple objects were found.' - ' Use Person.select(...) to retrieve them') - def test_getObjectByGetException(self): - p = Person.get(name="Philip") - - def test_updateObject(self): - with db_session: - Person[2].age=42 - self.assertEqual(Person[2].age, 42) - commit() - - @raises_exception(ObjectNotFound, 'Person[2]') - def test_deleteObject(self): - with db_session: - Person[2].delete() - p = Person[2] - - def test_bulkDelete(self): - with orm.db_session: - Person(id=4, name='Klaus', lastName='Mem', age=12) - Person(id=5, name='Abraham', lastName='Wrangler', age=13) - Person(id=6, name='Kira', lastName='Phito', age=20) - delete(p for p in Person if p.age <= 20) - self.assertEqual(select(p for p in Person if p.age <= 20).count(), 0) - - def test_bulkDeleteV2(self): - with orm.db_session: - Person(id=4, name='Klaus', lastName='Mem', age=12) - Person(id=5, name='Abraham', lastName='Wrangler', age=13) - Person(id=6, name='Kira', lastName='Phito', age=20) - Person.select(lambda p: p.id >= 4).delete(bulk=True) - self.assertEqual(select(p for p in Person if p.id >= 4).count(), 0) - - @raises_exception(UnresolvableCyclicDependency, 'Cannot save cyclic chain: Person -> Group') - def test_saveChainsException(self): - with orm.db_session: - claire = 
Person(name='Claire', lastName='Forlani') - annabel = Person(name='Annabel', lastName='Fiji') - Group(name='Aspen', persons=[claire, annabel], chief=claire) - print('group1=', Group[1]) - - def test_saveChainsWithFlush(self): - with orm.db_session: - claire = Person(name='Claire', lastName='Forlani') - annabel = Person(name='Annabel', lastName='Fiji') - flush() - Group(name='Aspen', persons=[claire, annabel], chief=claire) - self.assertEqual(Group[1].name, 'Aspen') - self.assertEqual(Group[1].chief.lastName, 'Forlani') \ No newline at end of file diff --git a/pony/orm/tests/test_entity_proxy.py b/pony/orm/tests/test_entity_proxy.py index ab8130ab4..4cc319c9d 100644 --- a/pony/orm/tests/test_entity_proxy.py +++ b/pony/orm/tests/test_entity_proxy.py @@ -2,23 +2,23 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -class TestProxy(unittest.TestCase): - def setUp(self): - db = self.db = Database('sqlite', ':memory:') - - class Country(db.Entity): - id = PrimaryKey(int) - name = Required(str) - persons = Set("Person") +db = Database() - class Person(db.Entity): - id = PrimaryKey(int) - name = Required(str) - country = Required(Country) +class Country(db.Entity): + id = PrimaryKey(int) + name = Required(str) + persons = Set("Person") - db.generate_mapping(create_tables=True) +class Person(db.Entity): + id = PrimaryKey(int) + name = Required(str) + country = Required(Country) +class TestProxy(unittest.TestCase): + def setUp(self): + setup_database(db) with db_session: c1 = Country(id=1, name='Russia') c2 = Country(id=2, name='Japan') @@ -26,11 +26,12 @@ class Person(db.Entity): Person(id=2, name='Raikou Minamoto', country=c2) Person(id=3, name='Ibaraki Douji', country=c2) + def tearDown(self): + teardown_database(db) def test_1(self): - db = self.db with db_session: - p = make_proxy(db.Person[2]) + p = make_proxy(Person[2]) with db_session: x1 = db.local_stats[None].db_count # number of queries 
@@ -43,9 +44,8 @@ def test_1(self): self.assertEqual(x1, x2-1) def test_2(self): - db = self.db with db_session: - p = make_proxy(db.Person[2]) + p = make_proxy(Person[2]) name = p.name country = p.country @@ -59,13 +59,12 @@ def test_2(self): self.assertEqual(x1, x2-1) def test_3(self): - db = self.db with db_session: - p = db.Person[2] + p = Person[2] proxy = make_proxy(p) with db_session: - p2 = db.Person[2] + p2 = Person[2] name1 = 'Tamamo no Mae' # It is possible to assign new attribute values to a proxy object p2.name = name1 @@ -75,13 +74,12 @@ def test_3(self): def test_4(self): - db = self.db with db_session: - p = db.Person[2] + p = Person[2] proxy = make_proxy(p) with db_session: - p2 = db.Person[2] + p2 = Person[2] name1 = 'Tamamo no Mae' p2.name = name1 @@ -92,9 +90,8 @@ def test_4(self): self.assertEqual(name1, name2) def test_5(self): - db = self.db with db_session: - p = db.Person[2] + p = Person[2] r = repr(p) self.assertEqual(r, 'Person[2]') @@ -114,9 +111,8 @@ def test_5(self): def test_6(self): - db = self.db with db_session: - p = db.Person[2] + p = Person[2] proxy = make_proxy(p) proxy.name = 'Okita Souji' # after assignment, the attribute value is the same for the proxy and for the original object @@ -125,9 +121,8 @@ def test_6(self): def test_7(self): - db = self.db with db_session: - p = db.Person[2] + p = Person[2] proxy = make_proxy(p) proxy.name = 'Okita Souji' # after assignment, the attribute value is the same for the proxy and for the original object @@ -136,11 +131,10 @@ def test_7(self): def test_8(self): - db = self.db with db_session: - c1 = db.Country[1] + c1 = Country[1] c1_proxy = make_proxy(c1) - p2 = db.Person[2] + p2 = Person[2] self.assertNotEqual(p2.country, c1) self.assertNotEqual(p2.country, c1_proxy) # proxy can be used in attribute assignment @@ -150,11 +144,10 @@ def test_8(self): def test_9(self): - db = self.db with db_session: - c2 = db.Country[2] + c2 = Country[2] c2_proxy = make_proxy(c2) - persons = select(p for 
p in db.Person if p.country == c2_proxy) + persons = select(p for p in Person if p.country == c2_proxy) self.assertEqual({p.id for p in persons}, {2, 3}) if __name__ == '__main__': diff --git a/pony/orm/tests/test_exists.py b/pony/orm/tests/test_exists.py index 7227db88d..6c0f3a17b 100644 --- a/pony/orm/tests/test_exists.py +++ b/pony/orm/tests/test_exists.py @@ -2,8 +2,9 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Group(db.Entity): students = Set('Student') @@ -19,21 +20,27 @@ class Student(db.Entity): class Passport(db.Entity): student = Optional(Student) -db.generate_mapping(create_tables=True) -with db_session: - g1 = Group() - g2 = Group() +class TestExists(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(id=1) + g2 = Group(id=2) - p = Passport() + p = Passport(id=1) - Student(first_name='Mashu', last_name='Kyrielight', login='Shielder', group=g1) - Student(first_name='Okita', last_name='Souji', login='Sakura', group=g1) - Student(first_name='Francis', last_name='Drake', group=g2, graduated=True) - Student(first_name='Oda', last_name='Nobunaga', group=g2, graduated=True) - Student(first_name='William', last_name='Shakespeare', group=g2, graduated=True, passport=p) + Student(id=1, first_name='Mashu', last_name='Kyrielight', login='Shielder', group=g1) + Student(id=2, first_name='Okita', last_name='Souji', login='Sakura', group=g1) + Student(id=3, first_name='Francis', last_name='Drake', group=g2, graduated=True) + Student(id=4, first_name='Oda', last_name='Nobunaga', group=g2, graduated=True) + Student(id=5, first_name='William', last_name='Shakespeare', group=g2, graduated=True, passport=p) + + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestExists(unittest.TestCase): def setUp(self): rollback() db_session.__enter__() diff --git 
a/pony/orm/tests/test_filter.py b/pony/orm/tests/test_filter.py index 84d2b4ca9..be412b7ab 100644 --- a/pony/orm/tests/test_filter.py +++ b/pony/orm/tests/test_filter.py @@ -3,8 +3,17 @@ import unittest from pony.orm.tests.model1 import * +from pony.orm.tests import setup_database, teardown_database + class TestFilter(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + populate_db() + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_flush.py b/pony/orm/tests/test_flush.py index aab318169..bd5c6c1eb 100644 --- a/pony/orm/tests/test_flush.py +++ b/pony/orm/tests/test_flush.py @@ -4,21 +4,25 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -class TestFlush(unittest.TestCase): - def setUp(self): - self.db = Database('sqlite', ':memory:') - class Person(self.db.Entity): - name = Required(unicode) +class Person(db.Entity): + name = Required(unicode) - self.db.generate_mapping(create_tables=True) - def tearDown(self): - self.db = None +class TestFlush(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(self): + teardown_database(db) def test1(self): - Person = self.db.Person with db_session: a = Person(name='A') b = Person(name='B') @@ -29,10 +33,12 @@ def test1(self): b.flush() self.assertEqual(a.id, None) - self.assertEqual(b.id, 1) + self.assertIsNotNone(b.id) + b_id = b.id self.assertEqual(c.id, None) flush() - self.assertEqual(a.id, 2) - self.assertEqual(b.id, 1) - self.assertEqual(c.id, 3) + self.assertIsNotNone(a.id) + self.assertEqual(b.id, b_id) + self.assertIsNotNone(c.id) + self.assertEqual(len({a.id, b.id, c.id}), 3) diff --git a/pony/orm/tests/test_frames.py b/pony/orm/tests/test_frames.py index 6d4bc1a39..94a3bc1f1 100644 --- a/pony/orm/tests/test_frames.py +++ 
b/pony/orm/tests/test_frames.py @@ -5,21 +5,27 @@ from pony.orm.core import * import pony.orm.decompiling from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Person(db.Entity): name = Required(unicode) age = Required(int) -db.generate_mapping(create_tables=True) - -with db_session: - p1 = Person(name='John', age=22) - p2 = Person(name='Mary', age=18) - p3 = Person(name='Mike', age=25) class TestFrames(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p1 = Person(id=1, name='John', age=22) + p2 = Person(id=2, name='Mary', age=18) + p3 = Person(id=3, name='Mike', age=25) + + @classmethod + def tearDownClass(cls): + db.drop_all_tables(with_all_data=True) @db_session def test_select(self): diff --git a/pony/orm/tests/test_generator_db_session.py b/pony/orm/tests/test_generator_db_session.py index 8f54615d1..d232def19 100644 --- a/pony/orm/tests/test_generator_db_session.py +++ b/pony/orm/tests/test_generator_db_session.py @@ -5,14 +5,16 @@ from pony.orm.core import * from pony.orm.core import local from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + class TestGeneratorDbSession(unittest.TestCase): def setUp(self): - db = Database('sqlite', ':memory:') + db = Database() class Account(db.Entity): id = PrimaryKey(int) amount = Required(int) - db.generate_mapping(create_tables=True) + setup_database(db) self.db = db self.Account = Account @@ -23,6 +25,7 @@ class Account(db.Entity): a3 = Account(id=3, amount=3000) def tearDown(self): + teardown_database(self.db) assert local.db_session is None self.db = self.Account = None diff --git a/pony/orm/tests/test_get_pk.py b/pony/orm/tests/test_get_pk.py index 4579a83a0..000091af5 100644 --- a/pony/orm/tests/test_get_pk.py +++ b/pony/orm/tests/test_get_pk.py @@ -1,73 +1,67 @@ -from pony.py23compat import basestring - import 
unittest - -from pony.orm import * -from pony import orm -from pony.utils import cached_property from datetime import date +from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -class Test(unittest.TestCase): - - @cached_property - def db(self): - return orm.Database('sqlite', ':memory:') +day = date.today() - def setUp(self): - db = self.db - self.day = date.today() +db = Database() - class A(db.Entity): - b = Required("B") - c = Required("C") - PrimaryKey(b, c) +class A(db.Entity): + b = Required("B") + c = Required("C") + PrimaryKey(b, c) - class B(db.Entity): - id = PrimaryKey(date) - a_set = Set(A) +class B(db.Entity): + id = PrimaryKey(date) + a_set = Set(A) - class C(db.Entity): - x = Required("X") - y = Required("Y") - a_set = Set(A) - PrimaryKey(x, y) +class C(db.Entity): + x = Required("X") + y = Required("Y") + a_set = Set(A) + PrimaryKey(x, y) - class X(db.Entity): - id = PrimaryKey(int) - c_set = Set(C) +class X(db.Entity): + id = PrimaryKey(int) + c_set = Set(C) - class Y(db.Entity): - id = PrimaryKey(int) - c_set = Set(C) +class Y(db.Entity): + id = PrimaryKey(int) + c_set = Set(C) - db.generate_mapping(check_tables=True, create_tables=True) - with orm.db_session: +class Test(unittest.TestCase): + def setUp(self): + setup_database(db) + with db_session: x1 = X(id=123) y1 = Y(id=456) - b1 = B(id=self.day) + b1 = B(id=day) c1 = C(x=x1, y=y1) A(b=b1, c=c1) + def tearDown(self): + teardown_database(db) @db_session def test_1(self): - a1 = self.db.A.select().first() - a2 = self.db.A[a1.get_pk()] + a1 = A.select().first() + a2 = A[a1.get_pk()] self.assertEqual(a1, a2) @db_session def test2(self): - a = self.db.A.select().first() - b = self.db.B.select().first() - c = self.db.C.select().first() + a = A.select().first() + b = B.select().first() + c = C.select().first() pk = (b.get_pk(), c._get_raw_pkval_()) - self.assertTrue(a is self.db.A[pk]) + self.assertTrue(a is A[pk]) @db_session def test3(self): - a = 
self.db.A.select().first() - c = self.db.C.select().first() - pk = (self.day, c.get_pk()) - self.assertTrue(a is self.db.A[pk]) \ No newline at end of file + a = A.select().first() + c = C.select().first() + pk = (day, c.get_pk()) + self.assertTrue(a is A[pk]) diff --git a/pony/orm/tests/test_getattr.py b/pony/orm/tests/test_getattr.py index 309e8dd6e..226271b52 100644 --- a/pony/orm/tests/test_getattr.py +++ b/pony/orm/tests/test_getattr.py @@ -6,12 +6,13 @@ from pony import orm from pony.utils import cached_property from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import db_params, setup_database, teardown_database class Test(unittest.TestCase): @cached_property def db(self): - return orm.Database('sqlite', ':memory:') + return orm.Database() def setUp(self): db = self.db @@ -30,7 +31,7 @@ class Artist(db.Entity): hobbies = orm.Set(Hobby) genres = orm.Set(Genre) - db.generate_mapping(check_tables=True, create_tables=True) + setup_database(db) with orm.db_session: pop = Genre(name='Pop') @@ -38,7 +39,10 @@ class Artist(db.Entity): Hobby(name='Swimming') pony.options.INNER_JOIN_SYNTAX = True - + + def tearDown(self): + teardown_database(self.db) + @db_session def test_no_caching(self): for attr_name, attr_type in zip(['name', 'age'], [basestring, int]): diff --git a/pony/orm/tests/test_hooks.py b/pony/orm/tests/test_hooks.py index 23e9047ec..af3c2002f 100644 --- a/pony/orm/tests/test_hooks.py +++ b/pony/orm/tests/test_hooks.py @@ -3,48 +3,68 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database, db_params logged_events = [] -db = Database('sqlite', ':memory:') +db = Database() + class Person(db.Entity): id = PrimaryKey(int) name = Required(unicode) age = Required(int) + def before_insert(self): logged_events.append('BI_' + self.name) do_before_insert(self) + def before_update(self): logged_events.append('BU_' + self.name) do_before_update(self) + def before_delete(self): 
logged_events.append('BD_' + self.name) do_before_delete(self) + def after_insert(self): logged_events.append('AI_' + self.name) do_after_insert(self) + def after_update(self): logged_events.append('AU_' + self.name) do_after_update(self) + def after_delete(self): logged_events.append('AD_' + self.name) do_after_delete(self) + def do_nothing(person): pass + def set_hooks_to_do_nothing(): global do_before_insert, do_before_update, do_before_delete global do_after_insert, do_after_update, do_after_delete do_before_insert = do_before_update = do_before_delete = do_nothing do_after_insert = do_after_update = do_after_delete = do_nothing + +db.bind(**db_params) +db.generate_mapping(check_tables=False) + set_hooks_to_do_nothing() -db.generate_mapping(create_tables=True) class TestHooks(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): set_hooks_to_do_nothing() @@ -90,7 +110,15 @@ def flush_for(*objects): for obj in objects: obj.flush() + class ObjectFlushTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) def setUp(self): set_hooks_to_do_nothing() diff --git a/pony/orm/tests/test_hybrid_methods_and_properties.py b/pony/orm/tests/test_hybrid_methods_and_properties.py index c1d454dc9..e5e146285 100644 --- a/pony/orm/tests/test_hybrid_methods_and_properties.py +++ b/pony/orm/tests/test_hybrid_methods_and_properties.py @@ -2,11 +2,13 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') +db = Database() sep = ' ' + class Person(db.Entity): id = PrimaryKey(int) first_name = Required(str) @@ -53,9 +55,11 @@ def complex_method(self): def simple_method(self): return self.complex_method() + class FakePerson(object): pass + p = FakePerson() p.last_name = '***' @@ 
-68,8 +72,6 @@ class Car(db.Entity): price = Required(int) color = Required(str) -db.generate_mapping(create_tables=True) - def simple_func(person): return person.full_name @@ -79,21 +81,28 @@ def complex_func(person): return person.complex_method() -with db_session: - p1 = Person(id=1, first_name='Alexander', last_name='Kozlovsky', favorite_color='white') - p2 = Person(id=2, first_name='Alexei', last_name='Malashkevich', favorite_color='green') - p3 = Person(id=3, first_name='Vitaliy', last_name='Abetkin') - p4 = Person(id=4, first_name='Alexander', last_name='Tischenko', favorite_color='blue') - c1 = Car(brand='Peugeot', model='306', owner=p1, year=2006, price=14000, color='red') - c2 = Car(brand='Honda', model='Accord', owner=p1, year=2007, price=13850, color='white') - c3 = Car(brand='Nissan', model='Skyline', owner=p2, year=2008, price=29900, color='black') - c4 = Car(brand='Volkswagen', model='Passat', owner=p1, year=2012, price=9400, color='blue') - c5 = Car(brand='Koenigsegg', model='CCXR', owner=p4, year=2016, price=4850000, color='white') - c6 = Car(brand='Lada', model='Kalina', owner=p4, year=2015, price=5000, color='white') +class TestHybridsAndProperties(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p1 = Person(id=1, first_name='Alexander', last_name='Kozlovsky', favorite_color='white') + p2 = Person(id=2, first_name='Alexei', last_name='Malashkevich', favorite_color='green') + p3 = Person(id=3, first_name='Vitaliy', last_name='Abetkin') + p4 = Person(id=4, first_name='Alexander', last_name='Tischenko', favorite_color='blue') + + c1 = Car(id=1, brand='Peugeot', model='306', owner=p1, year=2006, price=14000, color='red') + c2 = Car(id=2, brand='Honda', model='Accord', owner=p1, year=2007, price=13850, color='white') + c3 = Car(id=3, brand='Nissan', model='Skyline', owner=p2, year=2008, price=29900, color='black') + c4 = Car(id=4, brand='Volkswagen', model='Passat', owner=p1, year=2012, price=9400, 
color='blue') + c5 = Car(id=5, brand='Koenigsegg', model='CCXR', owner=p4, year=2016, price=4850000, color='white') + c6 = Car(id=6, brand='Lada', model='Kalina', owner=p4, year=2015, price=5000, color='white') + @classmethod + def tearDownClass(cls): + teardown_database(db) -class TestHybridsAndProperties(unittest.TestCase): @db_session def test1(self): persons = select(p.full_name for p in Person if p.has_car)[:] diff --git a/pony/orm/tests/test_indexes.py b/pony/orm/tests/test_indexes.py index d11febca8..9f8b3f453 100644 --- a/pony/orm/tests/test_indexes.py +++ b/pony/orm/tests/test_indexes.py @@ -4,10 +4,17 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, teardown_database class TestIndexes(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + + def tearDown(self): + teardown_database(self.db) + def test_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Person(db.Entity): name = Required(str) age = Required(int) @@ -22,7 +29,8 @@ class Person(db.Entity): self.assertEqual(i2.is_pk, False) self.assertEqual(i2.is_unique, True) - table = db.schema.tables['Person'] + table_name = 'Person' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'person' + table = db.schema.tables[table_name] name_column = table.column_dict['name'] age_column = table.column_dict['age'] self.assertEqual(len(table.indexes), 2) @@ -31,7 +39,7 @@ class Person(db.Entity): self.assertEqual(db_index.is_unique, True) def test_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Person(db.Entity): name = Required(str) age = Required(int) @@ -46,7 +54,8 @@ class Person(db.Entity): self.assertEqual(i2.is_pk, False) self.assertEqual(i2.is_unique, False) - table = db.schema.tables['Person'] + table_name = 'Person' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'person' + table = db.schema.tables[table_name] name_column = table.column_dict['name'] 
age_column = table.column_dict['age'] self.assertEqual(len(table.indexes), 2) @@ -55,11 +64,26 @@ class Person(db.Entity): self.assertEqual(db_index.is_unique, False) create_script = db.schema.generate_create_script() - index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")' - self.assertTrue(index_sql in create_script) + + + dialect = self.db.provider.dialect + if pony.__version__ < '0.9': + if dialect == 'SQLite': + index_sql = 'CREATE INDEX "idx_person__name_age" ON "Person" ("name", "age")' + else: + index_sql = 'CREATE INDEX "idx_person__name_age" ON "person" ("name", "age")' + elif dialect == 'MySQL' or dialect == 'SQLite': + index_sql = 'CREATE INDEX `idx_person__name__age` ON `person` (`name`, `age`)' + elif dialect == 'PostgreSQL': + index_sql = 'CREATE INDEX "idx_person__name__age" ON "person" ("name", "age")' + elif dialect == 'Oracle': + index_sql = 'CREATE INDEX "IDX_PERSON__NAME__AGE" ON "PERSON" ("NAME", "AGE")' + else: + raise NotImplementedError + self.assertIn(index_sql, create_script) def test_3(self): - db = Database('sqlite', ':memory:') + db = self.db class User(db.Entity): name = Required(str, unique=True) @@ -77,7 +101,7 @@ class User(db.Entity): self.assertEqual(u.name, 'B') def test_4(self): # issue 321 - db = Database('sqlite', ':memory:') + db = self.db class Person(db.Entity): name = Required(str) age = Required(int) @@ -93,7 +117,7 @@ class Person(db.Entity): p1.delete() def test_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Table1(db.Entity): name = Required(str) diff --git a/pony/orm/tests/test_inheritance.py b/pony/orm/tests/test_inheritance.py index 746be42d9..78badb3d8 100644 --- a/pony/orm/tests/test_inheritance.py +++ b/pony/orm/tests/test_inheritance.py @@ -4,11 +4,19 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, teardown_database + class TestInheritance(unittest.TestCase): + def setUp(self): + self.db = 
Database(**db_params) + + def tearDown(self): + if self.db.schema: + teardown_database(self.db) def test_0(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) @@ -17,7 +25,7 @@ class Entity1(db.Entity): self.assertEqual(Entity1._discriminator_, None) def test_1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(Entity1): @@ -43,7 +51,7 @@ class Entity4(Entity2, Entity3): @raises_exception(ERDiagramError, "Multiple inheritance graph must be diamond-like. " "Entity Entity3 inherits from Entity1 and Entity2 entities which don't have common base class.") def test_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(int) class Entity2(db.Entity): @@ -55,7 +63,7 @@ class Entity3(Entity1, Entity2): 'because both entities inherit from Entity1. ' 'To fix this, move attribute definition to base class') def test_3a(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(Entity1): @@ -64,7 +72,7 @@ class Entity3(Entity1): a = Required(int) def test3b(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) class Entity2(Entity1): @@ -80,7 +88,7 @@ class Entity4(db.Entity): @raises_exception(ERDiagramError, "Name 'a' hides base attribute Entity1.a") def test_4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) a = Required(int) @@ -89,14 +97,14 @@ class Entity2(Entity1): @raises_exception(ERDiagramError, "Primary key cannot be redefined in derived classes") def test_5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(int) class Entity2(Entity1): b = PrimaryKey(int) def test_6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Discriminator(str) b = 
Required(int) @@ -109,7 +117,7 @@ class Entity2(Entity1): @raises_exception(TypeError, "Discriminator value for entity Entity1 " "with custom discriminator column 'a' of 'int' type is not set") def test_7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Discriminator(int) b = Required(int) @@ -117,7 +125,7 @@ class Entity2(Entity1): c = Required(int) def test_8(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = 1 a = Discriminator(int) @@ -127,8 +135,8 @@ class Entity2(Entity1): c = Required(int) db.generate_mapping(create_tables=True) with db_session: - x = Entity1(b=10) - y = Entity2(b=10, c=20) + x = Entity1(id=1, b=10) + y = Entity2(id=2, b=10, c=20) with db_session: x = Entity1[1] y = Entity1[2] @@ -138,7 +146,7 @@ class Entity2(Entity1): self.assertEqual(y.a, 2) def test_9(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = '1' a = Discriminator(int) @@ -148,8 +156,8 @@ class Entity2(Entity1): c = Required(int) db.generate_mapping(create_tables=True) with db_session: - x = Entity1(b=10) - y = Entity2(b=10, c=20) + x = Entity1(id=1, b=10) + y = Entity2(id=2, b=10, c=20) with db_session: x = Entity1[1] y = Entity1[2] @@ -160,7 +168,7 @@ class Entity2(Entity1): @raises_exception(TypeError, "Incorrect discriminator value is set for Entity2 attribute 'a' of 'int' type: 'zzz'") def test_10(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = 1 a = Discriminator(int) @@ -170,7 +178,7 @@ class Entity2(Entity1): c = Required(int) def test_11(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): _discriminator_ = 1 a = Discriminator(int) @@ -179,7 +187,7 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, 'Invalid use of attribute Entity1.a in entity Entity2') def test_12(self): - db = Database('sqlite', ':memory:') + db = self.db 
class Entity1(db.Entity): a = Required(int) class Entity2(db.Entity): @@ -187,7 +195,7 @@ class Entity2(db.Entity): composite_index(Entity1.a, b) def test_13(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) class Entity2(Entity1): @@ -197,7 +205,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.a, Entity2.b) ]) def test_14(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): d = Discriminator(str) a = Required(int) @@ -208,7 +216,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.d, Entity2.a, Entity2.b) ]) def test_15(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): d = Discriminator(str) a = Required(int) @@ -219,7 +227,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.d, Entity2.id, Entity2.a, Entity2.b) ]) def test_16(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) class Entity2(Entity1): @@ -229,7 +237,7 @@ class Entity2(Entity1): [ (Entity2.id,), (Entity2.classtype, Entity2.id, Entity2.a, Entity2.b) ]) def test_17(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): t = Discriminator(str) a = Required(int) @@ -241,24 +249,24 @@ class Entity2(Entity1): [ (Entity1.id,), (Entity1.t, Entity1.a, Entity1.b) ]) def test_18(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) - class Entity2(db.Entity1): + class Entity2(Entity1): b = Required(int) class Entity3(Entity1): c = Required(int) db.generate_mapping(create_tables=True) with db_session: - x = Entity1(a=10) - y = Entity2(a=20, b=30) - z = Entity3(a=40, c=50) + x = Entity1(id=1, a=10) + y = Entity2(id=2, a=20, b=30) + z = Entity3(id=3, a=40, c=50) with db_session: result = select(e for e in Entity1 if e.b == 30 or e.c == 50) self.assertEqual([ e.id for e in result ], [ 2, 3 ]) def test_discriminator_1(self): - db = Database('sqlite', 
':memory:') + db = self.db class Entity1(db.Entity): a = Discriminator(str) b = Required(int) @@ -283,7 +291,7 @@ class Entity2(db.Entity1): @raises_exception(TypeError, "Invalid discriminator attribute value for Foo. Expected: 'Foo', got: 'Baz'") def test_discriminator_2(self): - db = Database('sqlite', ':memory:') + db = self.db class Foo(db.Entity): id = PrimaryKey(int) a = Discriminator(str) diff --git a/pony/orm/tests/test_inner_join_syntax.py b/pony/orm/tests/test_inner_join_syntax.py index c70238aed..cfdbf56c2 100644 --- a/pony/orm/tests/test_inner_join_syntax.py +++ b/pony/orm/tests/test_inner_join_syntax.py @@ -2,32 +2,38 @@ from pony.orm import * from pony import orm +from pony.orm.tests import setup_database, teardown_database, only_for -class TestJoin(unittest.TestCase): +db = Database() - exclude_fixtures = {'test': ['clear_tables']} - @classmethod - def setUpClass(cls): - db = cls.db = Database('sqlite', ':memory:') +class Genre(db.Entity): + name = orm.Optional(str) # TODO primary key + artists = orm.Set('Artist') + favorite = orm.Optional(bool) + index = orm.Optional(int) + + +class Hobby(db.Entity): + name = orm.Required(str) + artists = orm.Set('Artist') - class Genre(db.Entity): - name = orm.Optional(str) # TODO primary key - artists = orm.Set('Artist') - favorite = orm.Optional(bool) - index = orm.Optional(int) - class Hobby(db.Entity): - name = orm.Required(str) - artists = orm.Set('Artist') +class Artist(db.Entity): + name = orm.Required(str) + age = orm.Optional(int) + hobbies = orm.Set(Hobby) + genres = orm.Set(Genre) - class Artist(db.Entity): - name = orm.Required(str) - age = orm.Optional(int) - hobbies = orm.Set(Hobby) - genres = orm.Set(Genre) +pony.options.INNER_JOIN_SYNTAX = True - db.generate_mapping(create_tables=True) + +@only_for('sqlite') +class TestJoin(unittest.TestCase): + exclude_fixtures = {'test': ['clear_tables']} + @classmethod + def setUpClass(cls): + setup_database(db) with orm.db_session: pop = Genre(name='pop') @@ 
-35,12 +41,14 @@ class Artist(db.Entity): Artist(name='Sia', age=40, genres=[pop, rock]) Artist(name='Lady GaGa', age=30, genres=[pop]) - pony.options.INNER_JOIN_SYNTAX = True + @classmethod + def tearDownClass(cls): + teardown_database(db) @db_session def test_join_1(self): - result = select(g.id for g in self.db.Genre for a in g.artists if a.name.startswith('S'))[:] - self.assertEqual(self.db.last_sql, """SELECT DISTINCT "g"."id" + result = select(g.id for g in db.Genre for a in g.artists if a.name.startswith('S'))[:] + self.assertEqual(db.last_sql, """SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre" @@ -50,9 +58,9 @@ def test_join_1(self): @db_session def test_join_2(self): - result = select(g.id for g in self.db.Genre for a in self.db.Artist + result = select(g.id for g in db.Genre for a in db.Artist if JOIN(a in g.artists) and a.name.startswith('S'))[:] - self.assertEqual(self.db.last_sql, """SELECT DISTINCT "g"."id" + self.assertEqual(db.last_sql, """SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre", "Artist" "a" @@ -62,9 +70,9 @@ def test_join_2(self): @db_session def test_join_3(self): - result = select(g.id for g in self.db.Genre for x in self.db.Artist for a in self.db.Artist + result = select(g.id for g in db.Genre for x in db.Artist for a in db.Artist if JOIN(a in g.artists) and a.name.startswith('S') and g.id == x.id)[:] - self.assertEqual(self.db.last_sql, '''SELECT DISTINCT "g"."id" + self.assertEqual(db.last_sql, '''SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre", "Artist" "x", "Artist" "a" @@ -74,9 +82,9 @@ def test_join_3(self): @db_session def test_join_4(self): - result = select(g.id for g in self.db.Genre for a in self.db.Artist for x in self.db.Artist + result = select(g.id for g in db.Genre for a in db.Artist for x in db.Artist if JOIN(a in g.artists) and a.name.startswith('S') and g.id 
== x.id)[:] - self.assertEqual(self.db.last_sql, '''SELECT DISTINCT "g"."id" + self.assertEqual(db.last_sql, '''SELECT DISTINCT "g"."id" FROM "Genre" "g" INNER JOIN "Artist_Genre" "t-1" ON "g"."id" = "t-1"."genre", "Artist" "a", "Artist" "x" diff --git a/pony/orm/tests/test_isinstance.py b/pony/orm/tests/test_isinstance.py index cd7be64cc..e7597a74c 100644 --- a/pony/orm/tests/test_isinstance.py +++ b/pony/orm/tests/test_isinstance.py @@ -5,8 +5,10 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() -db = Database('sqlite', ':memory:', create_db=True) class Person(db.Entity): id = PrimaryKey(int, auto=True) @@ -14,26 +16,32 @@ class Person(db.Entity): dob = Optional(date) ssn = Required(str, unique=True) + class Student(Person): group = Required("Group") mentor = Optional("Teacher") attend_courses = Set("Course") + class Teacher(Person): teach_courses = Set("Course") apprentices = Set("Student") salary = Required(Decimal) + class Assistant(Student, Teacher): pass + class Professor(Teacher): position = Required(str) + class Group(db.Entity): number = PrimaryKey(int) students = Set("Student") + class Course(db.Entity): name = Required(str) semester = Required(int) @@ -41,18 +49,24 @@ class Course(db.Entity): teachers = Set(Teacher) PrimaryKey(name, semester) -db.generate_mapping(create_tables=True) - -with db_session: - p = Person(name='Person1', ssn='SSN1') - g = Group(number=123) - prof = Professor(name='Professor1', salary=1000, position='position1', ssn='SSN5') - a1 = Assistant(name='Assistant1', group=g, salary=100, ssn='SSN4', mentor=prof) - a2 = Assistant(name='Assistant2', group=g, salary=200, ssn='SSN6', mentor=prof) - s1 = Student(name='Student1', group=g, ssn='SSN2', mentor=a1) - s2 = Student(name='Student2', group=g, ssn='SSN3') class TestVolatile(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + p = 
Person(name='Person1', ssn='SSN1') + g = Group(number=123) + prof = Professor(name='Professor1', salary=1000, position='position1', ssn='SSN5') + a1 = Assistant(name='Assistant1', group=g, salary=100, ssn='SSN4', mentor=prof) + a2 = Assistant(name='Assistant2', group=g, salary=200, ssn='SSN6', mentor=prof) + s1 = Student(name='Student1', group=g, ssn='SSN2', mentor=a1) + s2 = Student(name='Student2', group=g, ssn='SSN3') + + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test_1(self): q = select(p.name for p in Person if isinstance(p, Student)) diff --git a/pony/orm/tests/test_json.py b/pony/orm/tests/test_json.py index 2c01d5dc0..4c1dfa2cc 100644 --- a/pony/orm/tests/test_json.py +++ b/pony/orm/tests/test_json.py @@ -5,19 +5,21 @@ from pony.orm import * from pony.orm.tests.testutils import raises_exception, raises_if from pony.orm.ormtypes import Json, TrackedValue, TrackedList, TrackedDict +from pony.orm.tests import setup_database, teardown_database +db = Database() -db = Database('sqlite', ':memory:') class Product(db.Entity): name = Required(str) info = Optional(Json) tags = Optional(Json) -db.generate_mapping(create_tables=True) - class TestJson(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) def setUp(self): with db_session: @@ -64,6 +66,9 @@ def setUp(self): }, tags=['Tablets', 'Apple', 'Retina']) + @classmethod + def tearDownClass(cls): + teardown_database(db) def test(self): with db_session: @@ -613,12 +618,20 @@ def test_none_for_nonexistent_path(self): @db_session def test_str_cast(self): p = get(coalesce(str(p.name), 'empty') for p in Product) - self.assertTrue('AS text' in db.last_sql) + last_sql = db.last_sql + if db.provider.dialect == 'PostgreSQL': + self.assertTrue(')::text' in last_sql) + else: + self.assertTrue('AS text' in db.last_sql) @db_session def test_int_cast(self): p = get(coalesce(int(p.info['os']['version']), 0) for p in Product) - self.assertTrue('as integer' in 
db.last_sql) + last_sql = db.last_sql + if db.provider.dialect == 'PostgreSQL': + self.assertTrue(')::int' in last_sql) + else: + self.assertTrue('as integer' in last_sql) def test_nonzero(self): diff --git a/pony/orm/tests/test_lazy.py b/pony/orm/tests/test_lazy.py index 4a0314a46..8144e2a8e 100644 --- a/pony/orm/tests/test_lazy.py +++ b/pony/orm/tests/test_lazy.py @@ -3,19 +3,24 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database + class TestLazy(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + db = self.db = Database() class X(self.db.Entity): a = Required(int) b = Required(unicode, lazy=True) self.X = X - self.db.generate_mapping(create_tables=True) + setup_database(db) with db_session: - x1 = X(a=1, b='first') - x2 = X(a=2, b='second') - x3 = X(a=3, b='third') + x1 = X(id=1, a=1, b='first') + x2 = X(id=2, a=2, b='second') + x3 = X(id=3, a=3, b='third') + + def tearDown(self): + teardown_database(self.db) @db_session def test_lazy_1(self): diff --git a/pony/orm/tests/test_mapping.py b/pony/orm/tests/test_mapping.py index a86bfc85d..28a5fb444 100644 --- a/pony/orm/tests/test_mapping.py +++ b/pony/orm/tests/test_mapping.py @@ -5,13 +5,19 @@ from pony.orm.core import * from pony.orm.dbschema import DBSchemaError from pony.orm.tests.testutils import * +from pony.orm.tests import db_params, only_for + +@only_for('sqlite') class TestColumnsMapping(unittest.TestCase): + def setUp(self): + self.db = Database(**db_params) + # raise exception if mapping table by default is not found @raises_exception(OperationalError, 'no such table: Student') def test_table_check1(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = PrimaryKey(str) sql = "drop table if exists Student;" @@ -21,7 +27,7 @@ class Student(db.Entity): # no exception if table was specified def test_table_check2(self): - db = Database('sqlite', ':memory:') + db = self.db class 
Student(db.Entity): name = PrimaryKey(str) sql = """ @@ -38,7 +44,7 @@ class Student(db.Entity): # raise exception if specified mapping table is not found @raises_exception(OperationalError, 'no such table: Table1') def test_table_check3(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): _table_ = 'Table1' name = PrimaryKey(str) @@ -46,7 +52,7 @@ class Student(db.Entity): # no exception if table was specified def test_table_check4(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): _table_ = 'Table1' name = PrimaryKey(str) @@ -64,7 +70,7 @@ class Student(db.Entity): # 'id' field created if primary key is not defined @raises_exception(OperationalError, 'no such column: Student.id') def test_table_check5(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = Required(str) sql = """ @@ -79,7 +85,7 @@ class Student(db.Entity): # 'id' field created if primary key is not defined def test_table_check6(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = Required(str) sql = """ @@ -96,7 +102,7 @@ class Student(db.Entity): @raises_exception(DBSchemaError, "Column 'name' already exists in table 'Student'") def test_table_check7(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = Required(str, column='name') record = Required(str, column='name') @@ -113,7 +119,7 @@ class Student(db.Entity): # user can specify column name for an attribute def test_custom_column_name(self): - db = Database('sqlite', ':memory:') + db = self.db class Student(db.Entity): name = PrimaryKey(str, column='name1') sql = """ @@ -131,7 +137,7 @@ class Student(db.Entity): @raises_exception(ERDiagramError, 'At least one attribute of one-to-one relationship Entity1.attr1 - Entity2.attr2 must be optional') def test_relations1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = 
PrimaryKey(int) attr1 = Required("Entity2") @@ -142,7 +148,7 @@ class Entity2(db.Entity): # no exception Optional-Required def test_relations2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2") @@ -153,7 +159,7 @@ class Entity2(db.Entity): # no exception Optional-Required(column) def test_relations3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required("Entity2", column='a') @@ -163,7 +169,7 @@ class Entity2(db.Entity): db.generate_mapping(create_tables=True) def test_relations4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Required("Entity2") @@ -176,7 +182,7 @@ class Entity2(db.Entity): # no exception Optional-Optional def test_relations5(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2") @@ -187,7 +193,7 @@ class Entity2(db.Entity): # no exception Optional-Optional(column) def test_relations6(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a') @@ -197,7 +203,7 @@ class Entity2(db.Entity): db.generate_mapping(create_tables=True) def test_relations7(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional("Entity2", column='a') @@ -209,7 +215,7 @@ class Entity2(db.Entity): self.assertEqual(Entity2.attr2.columns, ['a1']) def test_columns1(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = PrimaryKey(int) attr1 = Set("Entity2") @@ -223,7 +229,7 @@ class Entity2(db.Entity): self.assertEqual(column_list[1].name, 'attr2') def test_columns2(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): a = Required(int) b = Required(int) @@ -240,7 
+246,7 @@ class Entity2(db.Entity): self.assertEqual(column_list[2].name, 'attr2_b') def test_columns3(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity1(db.Entity): id = PrimaryKey(int) attr1 = Optional('Entity2') @@ -252,7 +258,7 @@ class Entity2(db.Entity): self.assertEqual(Entity2.attr2.columns, []) def test_columns4(self): - db = Database('sqlite', ':memory:') + db = self.db class Entity2(db.Entity): id = PrimaryKey(int) attr2 = Optional('Entity1') @@ -265,14 +271,14 @@ class Entity1(db.Entity): @raises_exception(ERDiagramError, "Mapping is not generated for entity 'E1'") def test_generate_mapping1(self): - db = Database('sqlite', ':memory:') + db = self.db class E1(db.Entity): a1 = Required(int) select(e for e in E1) @raises_exception(ERDiagramError, "Mapping is not generated for entity 'E1'") def test_generate_mapping2(self): - db = Database('sqlite', ':memory:') + db = self.db class E1(db.Entity): a1 = Required(int) e = E1(a1=1) diff --git a/pony/orm/tests/test_objects_to_save_cleanup.py b/pony/orm/tests/test_objects_to_save_cleanup.py index 331d87932..e00b09b23 100644 --- a/pony/orm/tests/test_objects_to_save_cleanup.py +++ b/pony/orm/tests/test_objects_to_save_cleanup.py @@ -1,31 +1,32 @@ - - import unittest - from pony.orm import * +from pony.orm.tests import setup_database, teardown_database -db = Database() - -class TestPost(db.Entity): - category = Optional('TestCategory') - name = Optional(str, default='Noname') +class EntityStatusTestCase(object): + @classmethod + def setUpClass(cls): + db = cls.db = Database() -class TestCategory(db.Entity): - posts = Set(TestPost) + class TestPost(db.Entity): + category = Optional('TestCategory') + name = Optional(str, default='Noname') -db.bind('sqlite', ':memory:') -db.generate_mapping(create_tables=True) + class TestCategory(db.Entity): + posts = Set(TestPost) + setup_database(db) -class EntityStatusTestCase(object): + @classmethod + def tearDownClass(cls): + teardown_database(cls.db) def 
make_flush(self, obj=None): raise NotImplementedError @db_session def test_delete_updated(self): - p = TestPost() + p = self.db.TestPost() self.make_flush(p) p.name = 'Pony' self.assertEqual(p._status_, 'modified') @@ -38,7 +39,7 @@ def test_delete_updated(self): @db_session def test_delete_inserted(self): - p = TestPost() + p = self.db.TestPost() self.assertEqual(p._status_, 'created') self.make_flush(p) self.assertEqual(p._status_, 'inserted') @@ -46,7 +47,7 @@ def test_delete_inserted(self): @db_session def test_cancelled(self): - p = TestPost() + p = self.db.TestPost() self.assertEqual(p._status_, 'created') p.delete() self.assertEqual(p._status_, 'cancelled') diff --git a/pony/orm/tests/test_prefetching.py b/pony/orm/tests/test_prefetching.py index 68509a72f..d77cb2870 100644 --- a/pony/orm/tests/test_prefetching.py +++ b/pony/orm/tests/test_prefetching.py @@ -4,8 +4,10 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(str) @@ -17,36 +19,45 @@ class Student(db.Entity): mentor = Optional('Teacher') biography = Optional(LongStr) + class Group(db.Entity): number = PrimaryKey(int) major = Required(str, lazy=True) students = Set(Student) + class Course(db.Entity): name = Required(str, unique=True) students = Set(Student) + class Teacher(db.Entity): name = Required(str) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1, major='Math') - g2 = Group(number=2, major='Computer Sciense') - c1 = Course(name='Math') - c2 = Course(name='Physics') - c3 = Course(name='Computer Science') - t1 = Teacher(name='T1') - t2 = Teacher(name='T2') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio', mentor=t1) - Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio') - Student(id=3, 
name='S3', group=g1, gpa=4.7, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) - Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3], mentor=t2) - Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3]) class TestPrefetching(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=1, major='Math') + g2 = Group(number=2, major='Computer Sciense') + c1 = Course(name='Math') + c2 = Course(name='Physics') + c3 = Course(name='Computer Science') + t1 = Teacher(name='T1') + t2 = Teacher(name='T2') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='S1 bio', mentor=t1) + Student(id=2, name='S2', group=g1, gpa=4.2, scholarship=100, dob=date(2000, 1, 1), biography='S2 bio') + Student(id=3, name='S3', group=g1, gpa=4.7, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4', group=g2, gpa=3.2, biography='S4 bio', courses=[c1, c3], mentor=t2) + Student(id=5, name='S5', group=g2, gpa=4.5, biography='S5 bio', courses=[c1, c3]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def test_1(self): with db_session: s1 = Student.select().first() @@ -115,11 +126,14 @@ def test_12(self): with db_session: s1 = Student.select().prefetch(Student.biography).first() self.assertEqual(s1.biography, 'S1 bio') - self.assertEqual(db.last_sql, -'''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."mentor", "s"."biography" -FROM "Student" "s" + table_name = 'Student' if db.provider.dialect == 'SQLite' and pony.__version__ < '0.9' else 'student' + expected_sql = '''SELECT "s"."id", "s"."name", "s"."scholarship", "s"."gpa", "s"."dob", "s"."group", "s"."mentor", "s"."biography" +FROM "%s" "s" ORDER BY 1 -LIMIT 1''') +LIMIT 1''' % table_name + if db.provider.dialect == 'SQLite' and pony.__version__ >= '0.9': + expected_sql = expected_sql.replace('"', '`') + 
self.assertEqual(db.last_sql, expected_sql) def test_13(self): db.merge_local_stats() diff --git a/pony/orm/tests/test_query.py b/pony/orm/tests/test_query.py index 94603fe06..6623463da 100644 --- a/pony/orm/tests/test_query.py +++ b/pony/orm/tests/test_query.py @@ -7,8 +7,10 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import teardown_database, setup_database + +db = Database() -db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) @@ -17,19 +19,26 @@ class Student(db.Entity): group = Required('Group') dob = Optional(date) + class Group(db.Entity): number = PrimaryKey(int) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - Student(id=1, name='S1', group=g1, gpa=3.1) - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2)) class TestQuery(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=1) + Student(id=1, name='S1', group=g1, gpa=3.1) + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2)) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() diff --git a/pony/orm/tests/test_random.py b/pony/orm/tests/test_random.py index 9a47efaa1..e0bcef9c6 100644 --- a/pony/orm/tests/test_random.py +++ b/pony/orm/tests/test_random.py @@ -2,23 +2,31 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): id = PrimaryKey(int) name = Required(unicode) -db.generate_mapping(create_tables=True) - -with db_session: - Person(id=1, name='John') - 
Person(id=2, name='Mary') - Person(id=3, name='Bob') - Person(id=4, name='Mike') - Person(id=5, name='Ann') class TestRandom(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Person(id=1, name='John') + Person(id=2, name='Mary') + Person(id=3, name='Bob') + Person(id=4, name='Mike') + Person(id=5, name='Ann') + + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test_1(self): persons = Person.select().random(2) diff --git a/pony/orm/tests/test_raw_sql.py b/pony/orm/tests/test_raw_sql.py index 4c869be73..aa12eee7e 100644 --- a/pony/orm/tests/test_raw_sql.py +++ b/pony/orm/tests/test_raw_sql.py @@ -6,8 +6,10 @@ from pony.orm import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): id = PrimaryKey(int) @@ -15,14 +17,21 @@ class Person(db.Entity): age = Required(int) dob = Required(date) -db.generate_mapping(create_tables=True) - -with db_session: - Person(id=1, name='John', age=30, dob=date(1985, 1, 1)) - Person(id=2, name='Mike', age=32, dob=date(1983, 5, 20)) - Person(id=3, name='Mary', age=20, dob=date(1995, 2, 15)) +@only_for('sqlite') class TestRawSQL(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + Person(id=1, name='John', age=30, dob=date(1985, 1, 1)) + Person(id=2, name='Mike', age=32, dob=date(1983, 5, 20)) + Person(id=3, name='Mary', age=20, dob=date(1995, 2, 15)) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test_1(self): # raw_sql result can be treated as a logical expression diff --git a/pony/orm/tests/test_relations_m2m.py b/pony/orm/tests/test_relations_m2m.py index 2f1708986..dc0552653 100644 --- a/pony/orm/tests/test_relations_m2m.py +++ b/pony/orm/tests/test_relations_m2m.py @@ -2,26 +2,30 @@ import unittest from 
pony.orm.core import * +from pony.orm.tests import db_params, teardown_database -class TestManyToManyNonComposite(unittest.TestCase): +db = Database() - def setUp(self): - db = Database('sqlite', ':memory:') - class Group(db.Entity): - number = PrimaryKey(int) - subjects = Set("Subject") +class Group(db.Entity): + number = PrimaryKey(int) + subjects = Set("Subject") + - class Subject(db.Entity): - name = PrimaryKey(str) - groups = Set(Group) +class Subject(db.Entity): + name = PrimaryKey(str) + groups = Set(Group) - self.db = db - self.Group = Group - self.Subject = Subject - self.db.generate_mapping(create_tables=True) +class TestManyToManyNonComposite(unittest.TestCase): + @classmethod + def setUpClass(cls): + db.bind(**db_params) + db.generate_mapping(check_tables=False) + db.drop_all_tables(with_all_data=True) + def setUp(self): + db.create_tables() with db_session: g1 = Group(number=101) g2 = Group(number=102) @@ -31,19 +35,25 @@ class Subject(db.Entity): s4 = Subject(name='Subj4') g1.subjects = [ s1, s2 ] + def tearDown(self): + teardown_database(db) + def test_1(self): - schema = self.db.schema + schema = db.schema m2m_table_name = 'Group_Subject' + if not (db.provider.dialect == 'SQLite' and pony.__version__ < '0.9'): + m2m_table_name = m2m_table_name.lower() self.assertIn(m2m_table_name, schema.tables) m2m_table = schema.tables[m2m_table_name] - fkeys = list(m2m_table.foreign_keys.values()) + if pony.__version__ >= '0.9': + fkeys = m2m_table.foreign_keys + else: + fkeys = set(m2m_table.foreign_keys.values()) self.assertEqual(len(fkeys), 2) for fk in fkeys: self.assertEqual(fk.on_delete, 'CASCADE') def test_2(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s = Subject.get(name='Subj1') @@ -51,11 +61,9 @@ def test_2(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + 
self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) def test_3(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s = Subject.get(name='Subj3') @@ -63,11 +71,9 @@ def test_3(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2', 'Subj3']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2', 'Subj3'}) def test_4(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s = Subject.get(name='Subj3') @@ -75,11 +81,9 @@ def test_4(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) def test_5(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s = Subject.get(name='Subj2') @@ -87,11 +91,9 @@ def test_5(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1']) - - def test_5(self): - db, Group, Subject = self.db, self.Group, self.Subject + self.assertEqual(set(db_subjects), {'Subj1'}) + def test_6(self): with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -100,12 +102,10 @@ def test_5(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj3', 'Subj4']) + self.assertEqual(set(db_subjects), {'Subj3', 'Subj4'}) self.assertEqual(Group[101].subjects, {Subject['Subj3'], Subject['Subj4']}) def test_7(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s = Subject.get(name='Subj3') @@ -116,11 +116,9 @@ def test_7(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no DELETE statement on commit 
db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) def test_8(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s = Subject.get(name='Subj1') @@ -131,11 +129,9 @@ def test_8(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no INSERT statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) def test_9(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s1 = Subject.get(name='Subj1') @@ -147,11 +143,9 @@ def test_9(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no INSERT statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) def test_10(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g2 = Group.get(number=102) s1 = Subject.get(name='Subj1') @@ -165,8 +159,6 @@ def test_10(self): self.assertEqual(db_subjects , []) def test_11(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -174,11 +166,9 @@ def test_11(self): with db_session: db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj2', 'Subj3']) + self.assertEqual(set(db_subjects), {'Subj2', 'Subj3'}) def test_12(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -189,11 +179,9 @@ def test_12(self): with db_session: self.assertEqual(db.last_sql, 
last_sql) # assert no INSERT statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) def test_13(self): - db, Group, Subject = self.db, self.Group, self.Subject - with db_session: g = Group.get(number=101) s1, s2, s3, s4 = Subject.select()[:] @@ -204,12 +192,10 @@ def test_13(self): with db_session: self.assertEqual(db.last_sql, last_sql) # assert no DELETE statement on commit db_subjects = db.select('subject from Group_Subject where "group" = 101') - self.assertEqual(db_subjects , ['Subj1', 'Subj2']) + self.assertEqual(set(db_subjects), {'Subj1', 'Subj2'}) @db_session def test_14(self): - db, Group, Subject = self.db, self.Group, self.Subject - g1 = Group[101] s1 = Subject['Subj1'] self.assertTrue(s1 in g1.subjects) @@ -242,8 +228,6 @@ def test_14(self): @db_session def test_15(self): - db, Group, Subject = self.db, self.Group, self.Subject - g = Group[101] e = g.subjects.is_empty() self.assertEqual(e, False) @@ -264,8 +248,6 @@ def test_15(self): @db_session def test_16(self): - db, Group = self.db, self.Group - g = Group[101] c = len(g.subjects) self.assertEqual(c, 2) @@ -284,8 +266,6 @@ def test_16(self): @db_session def test_17(self): - db, Group, Subject = self.db, self.Group, self.Subject - g = Group[101] s1 = Subject['Subj1'] s3 = Subject['Subj3'] diff --git a/pony/orm/tests/test_relations_one2many.py b/pony/orm/tests/test_relations_one2many.py index 9caa55401..9b1efe800 100644 --- a/pony/orm/tests/test_relations_one2many.py +++ b/pony/orm/tests/test_relations_one2many.py @@ -4,11 +4,12 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -class TestOneToManyRequired(unittest.TestCase): +class TestOneToManyRequired(unittest.TestCase): def setUp(self): - db = Database('sqlite', ':memory:', create_db=True) + db = Database() class 
Student(db.Entity): id = PrimaryKey(int) @@ -23,7 +24,7 @@ class Group(db.Entity): self.Group = Group self.Student = Student - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: g101 = Group(number=101) @@ -39,6 +40,7 @@ class Group(db.Entity): def tearDown(self): rollback() db_session.__exit__() + teardown_database(self.db) @raises_exception(ValueError, 'Attribute Student[1].group is required') def test_1(self): diff --git a/pony/orm/tests/test_relations_one2one1.py b/pony/orm/tests/test_relations_one2one1.py index 5971ac48e..046555a8b 100644 --- a/pony/orm/tests/test_relations_one2one1.py +++ b/pony/orm/tests/test_relations_one2one1.py @@ -2,20 +2,30 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Male(db.Entity): name = Required(unicode) wife = Optional('Female', column='wife') + class Female(db.Entity): name = Required(unicode) husband = Optional('Male') -db.generate_mapping(create_tables=True) class TestOneToOne(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('delete from male') @@ -115,16 +125,16 @@ def test_8(self): @db_session def test_9(self): - f4 = Female(name='F4') - m4 = Male(name='M4', wife=f4) + f4 = Female(id=4, name='F4') + m4 = Male(id=4, name='M4', wife=f4) flush() self.assertEqual(f4._status_, 'inserted') self.assertEqual(m4._status_, 'inserted') @db_session def test_10(self): - m4 = Male(name='M4') - f4 = Female(name='F4', husband=m4) + m4 = Male(id=4, name='M4') + f4 = Female(id=4, name='F4', husband=m4) flush() self.assertEqual(f4._status_, 'inserted') self.assertEqual(m4._status_, 'inserted') diff --git a/pony/orm/tests/test_relations_one2one2.py b/pony/orm/tests/test_relations_one2one2.py index c3f5f303b..7a7d9809e 100644 --- 
a/pony/orm/tests/test_relations_one2one2.py +++ b/pony/orm/tests/test_relations_one2one2.py @@ -4,20 +4,30 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import teardown_database, setup_database + +db = Database() -db = Database('sqlite', ':memory:') class Male(db.Entity): name = Required(unicode) wife = Optional('Female', column='wife') + class Female(db.Entity): name = Required(unicode) husband = Optional('Male', column='husband') -db.generate_mapping(create_tables=True) class TestOneToOne2(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('update female set husband=null') diff --git a/pony/orm/tests/test_relations_one2one3.py b/pony/orm/tests/test_relations_one2one3.py index dffddb5ce..09aa1d2c8 100644 --- a/pony/orm/tests/test_relations_one2one3.py +++ b/pony/orm/tests/test_relations_one2one3.py @@ -4,10 +4,13 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database, only_for + +@only_for('sqlite') class TestOneToOne3(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + self.db = Database() class Person(self.db.Entity): name = Required(unicode) @@ -17,14 +20,14 @@ class Passport(self.db.Entity): code = Required(unicode) person = Required("Person") - self.db.generate_mapping(create_tables=True) + setup_database(self.db) with db_session: p1 = Person(name='John') Passport(code='123', person=p1) def tearDown(self): - self.db = None + teardown_database(self.db) @db_session def test_1(self): diff --git a/pony/orm/tests/test_relations_one2one4.py b/pony/orm/tests/test_relations_one2one4.py index 0bbb8306f..f9813bea3 100644 --- a/pony/orm/tests/test_relations_one2one4.py +++ b/pony/orm/tests/test_relations_one2one4.py @@ -4,35 +4,35 @@ from pony.orm.core import * from 
pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -class TestOneToOne4(unittest.TestCase): - def setUp(self): - self.db = Database('sqlite', ':memory:') - - class Person(self.db.Entity): - name = Required(unicode) - passport = Optional("Passport") +db = Database() - class Passport(self.db.Entity): - code = Required(unicode) - person = Required("Person") +class Person(db.Entity): + name = Required(unicode) + passport = Optional("Passport") - self.db.generate_mapping(create_tables=True) +class Passport(db.Entity): + code = Required(unicode) + person = Required("Person") +class TestOneToOne4(unittest.TestCase): + def setUp(self): + setup_database(db) with db_session: - p1 = Person(name='John') - Passport(code='123', person=p1) + p1 = Person(id=1, name='John') + Passport(id=1, code='123', person=p1) def tearDown(self): - self.db = None + teardown_database(db) @raises_exception(ConstraintError, 'Cannot unlink Passport[1] from previous Person[1] object, because Passport.person attribute is required') @db_session def test1(self): - p2 = self.db.Person(name='Mike') - pas2 = self.db.Passport(code='456', person=p2) + p2 = Person(id=2, name='Mike') + pas2 = Passport(id=2, code='456', person=p2) commit() - p1 = self.db.Person.get(name='John') + p1 = Person.get(name='John') pas2.person = p1 if __name__ == '__main__': diff --git a/pony/orm/tests/test_relations_symmetric_m2m.py b/pony/orm/tests/test_relations_symmetric_m2m.py index b5762228d..79f49b699 100644 --- a/pony/orm/tests/test_relations_symmetric_m2m.py +++ b/pony/orm/tests/test_relations_symmetric_m2m.py @@ -2,15 +2,25 @@ import unittest from pony.orm.core import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) friends = Set('Person', reverse='friends') -db.generate_mapping(create_tables=True) + class TestSymmetricM2M(unittest.TestCase): + @classmethod + 
def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: for p in Person.select(): p.delete() diff --git a/pony/orm/tests/test_relations_symmetric_one2one.py b/pony/orm/tests/test_relations_symmetric_one2one.py index 47d1bad53..9dc72f20f 100644 --- a/pony/orm/tests/test_relations_symmetric_one2one.py +++ b/pony/orm/tests/test_relations_symmetric_one2one.py @@ -4,16 +4,25 @@ from pony.orm.core import * from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import setup_database, teardown_database, only_for + +db = Database() -db = Database('sqlite', ':memory:') class Person(db.Entity): name = Required(unicode) spouse = Optional('Person', reverse='spouse') -db.generate_mapping(create_tables=True) class TestSymmetricOne2One(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): with db_session: db.execute('update person set spouse=null') diff --git a/pony/orm/tests/test_select_from_select_queries.py b/pony/orm/tests/test_select_from_select_queries.py index a8f7f860d..5a2bee399 100644 --- a/pony/orm/tests/test_select_from_select_queries.py +++ b/pony/orm/tests/test_select_from_select_queries.py @@ -3,14 +3,17 @@ from pony.orm import * from pony.orm.tests.testutils import * from pony.py23compat import PYPY2 +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Group(db.Entity): number = PrimaryKey(int) major = Required(str) students = Set('Student') + class Student(db.Entity): first_name = Required(unicode) last_name = Required(unicode) @@ -23,6 +26,7 @@ class Student(db.Entity): def full_name(self): return self.first_name + ' ' + self.last_name + class Course(db.Entity): name = Required(unicode) semester = Required(int) @@ -30,22 +34,27 @@ class Course(db.Entity): PrimaryKey(name, 
semester) students = Set('Student') -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=123, major='Computer Science') - g2 = Group(number=456, major='Graphic Design') - s1 = Student(id=1, first_name='John', last_name='Smith', age=20, group=g1, scholarship=0) - s2 = Student(id=2, first_name='Alex', last_name='Green', age=24, group=g1, scholarship=100) - s3 = Student(id=3, first_name='Mary', last_name='White', age=23, group=g1, scholarship=500) - s4 = Student(id=4, first_name='John', last_name='Brown', age=20, group=g2, scholarship=400) - s5 = Student(id=5, first_name='Bruce', last_name='Lee', age=22, group=g2, scholarship=300) - c1 = Course(name='Math', semester=1, credits=10, students=[s1, s2, s4]) - c2 = Course(name='Computer Science', semester=1, credits=20, students=[s2, s3]) - c3 = Course(name='3D Modeling', semester=2, credits=15, students=[s3, s5]) - class TestSelectFromSelect(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + g1 = Group(number=123, major='Computer Science') + g2 = Group(number=456, major='Graphic Design') + s1 = Student(id=1, first_name='John', last_name='Smith', age=20, group=g1, scholarship=0) + s2 = Student(id=2, first_name='Alex', last_name='Green', age=24, group=g1, scholarship=100) + s3 = Student(id=3, first_name='Mary', last_name='White', age=23, group=g1, scholarship=500) + s4 = Student(id=4, first_name='John', last_name='Brown', age=20, group=g2, scholarship=400) + s5 = Student(id=5, first_name='Bruce', last_name='Lee', age=22, group=g2, scholarship=300) + c1 = Course(name='Math', semester=1, credits=10, students=[s1, s2, s4]) + c2 = Course(name='Computer Science', semester=1, credits=20, students=[s2, s3]) + c3 = Course(name='3D Modeling', semester=2, credits=15, students=[s3, s5]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + @db_session def test_1(self): # basic select from another query q = select(s for s in Student if 
s.scholarship > 0) @@ -375,7 +384,7 @@ def test_43(self): def test_44(self): q = select(s for s in Student).order_by(Student.first_name).limit(3, offset=1) q2 = select(s.first_name for s in q) - self.assertEqual(set(q2), {'Bruce', 'John', 'Mary'}) + self.assertEqual(list(q2), ['Bruce', 'John', 'John']) @db_session def test_45(self): @@ -386,7 +395,7 @@ def test_45(self): @db_session def test_46(self): - q = select((c, count(c.students)) for c in Course).order_by(-2).limit(2) + q = select((c, count(c.students)) for c in Course).order_by(-2, 1).limit(2) q2 = select((c.name, c.credits, m) for c, m in q).limit(1, offset=1) self.assertEqual(set(q2), {('3D Modeling', 15, 2)}) diff --git a/pony/orm/tests/test_show.py b/pony/orm/tests/test_show.py index dc12832ac..e5846b800 100644 --- a/pony/orm/tests/test_show.py +++ b/pony/orm/tests/test_show.py @@ -6,8 +6,10 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + +db = Database() -db = Database('sqlite', ':memory:') class Student(db.Entity): name = Required(unicode) @@ -26,21 +28,29 @@ class Course(db.Entity): name = Required(unicode, unique=True) students = Set(Student) -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - g2 = Group(number=2) - c1 = Course(name='Math') - c2 = Course(name='Physics') - c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) normal_stdout = sys.stdout + class TestShow(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + c1 = Course(name='Math') + c2 = Course(name='Physics') + c3 = Course(name='Computer Science') + Student(id=1, name='S1', 
group=g1, gpa=3.1, courses=[c1, c2], biography='some text') + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + + @classmethod + def tearDownClass(cls): + teardown_database(db) + def setUp(self): rollback() db_session.__enter__() @@ -70,5 +80,6 @@ def test_2(self): 2~~~~~ ''') + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_sqlbuilding_formatstyles.py b/pony/orm/tests/test_sqlbuilding_formatstyles.py index 40804ffea..7f2a482dd 100644 --- a/pony/orm/tests/test_sqlbuilding_formatstyles.py +++ b/pony/orm/tests/test_sqlbuilding_formatstyles.py @@ -6,6 +6,7 @@ from pony.orm.dbapiprovider import DBAPIProvider from pony.orm.tests.testutils import TestPool + class TestFormatStyles(unittest.TestCase): def setUp(self): self.key1 = 'KEY1' diff --git a/pony/orm/tests/test_sqlbuilding_sqlast.py b/pony/orm/tests/test_sqlbuilding_sqlast.py index 3e4cae9c9..9d4493863 100644 --- a/pony/orm/tests/test_sqlbuilding_sqlast.py +++ b/pony/orm/tests/test_sqlbuilding_sqlast.py @@ -3,10 +3,14 @@ import unittest from pony.orm.core import Database, db_session from pony.orm.sqlsymbols import * +from pony.orm.tests import setup_database, only_for + +@only_for('sqlite') class TestSQLAST(unittest.TestCase): def setUp(self): - self.db = Database('sqlite', ':memory:') + self.db = Database() + setup_database(self.db) with db_session: conn = self.db.get_connection() conn.executescript(""" @@ -16,6 +20,13 @@ def setUp(self): ); insert or ignore into T1 values(1, 'abc'); """) + + def tearDown(self): + with db_session: + conn = self.db.get_connection() + conn.executescript("""drop table T1 + """) + @db_session def test_alias(self): sql_ast = [SELECT, [ALL, [COLUMN, "Group", "a"]], @@ -29,5 +40,6 @@ def test_alias2(self): sql, adapter = self.db._ast2sql(sql_ast) cursor = self.db._exec_sql(sql) + if __name__ == "__main__": unittest.main() 
diff --git a/pony/orm/tests/test_sqlite_str_functions.py b/pony/orm/tests/test_sqlite_str_functions.py index 2592d57d3..656004472 100644 --- a/pony/orm/tests/test_sqlite_str_functions.py +++ b/pony/orm/tests/test_sqlite_str_functions.py @@ -7,14 +7,17 @@ from pony.orm.core import * from pony.orm.tests.testutils import * +from pony.orm.tests import only_for db = Database('sqlite', ':memory:') + class Person(db.Entity): name = Required(unicode) age = Optional(int) image = Optional(buffer) + db.generate_mapping(create_tables=True) with db_session: @@ -22,6 +25,7 @@ class Person(db.Entity): p2 = Person(name=u'Иван') # u'\u0418\u0432\u0430\u043d' +@only_for('sqlite') class TestUnicode(unittest.TestCase): @db_session def test1(self): @@ -58,5 +62,6 @@ def test7(self): ages = db.select('select py_lower(image) from person') self.assertEqual(ages, [u'abcdef', None]) + if __name__ == '__main__': unittest.main() diff --git a/pony/orm/tests/test_time_parsing.py b/pony/orm/tests/test_time_parsing.py index a9307f828..064b43191 100644 --- a/pony/orm/tests/test_time_parsing.py +++ b/pony/orm/tests/test_time_parsing.py @@ -6,6 +6,7 @@ from pony.orm.tests.testutils import raises_exception from pony.converting import str2time + class TestTimeParsing(unittest.TestCase): def test_time_1(self): self.assertEqual(str2time('1:2'), time(1, 2)) diff --git a/pony/orm/tests/test_to_dict.py b/pony/orm/tests/test_to_dict.py index 8e2da08fb..29a11b225 100644 --- a/pony/orm/tests/test_to_dict.py +++ b/pony/orm/tests/test_to_dict.py @@ -5,40 +5,48 @@ from pony.orm import * from pony.orm.serialization import to_dict from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database -db = Database('sqlite', ':memory:') - -class Student(db.Entity): - name = Required(unicode) - scholarship = Optional(int) - gpa = Optional(Decimal, 3, 1) - dob = Optional(date) - group = Optional('Group') - courses = Set('Course') - biography = Optional(LongUnicode) - -class 
Group(db.Entity): - number = PrimaryKey(int) - students = Set(Student) - -class Course(db.Entity): - name = Required(unicode, unique=True) - students = Set(Student) - -db.generate_mapping(create_tables=True) - -with db_session: - g1 = Group(number=1) - g2 = Group(number=2) - c1 = Course(name='Math') - c2 = Course(name='Physics') - c3 = Course(name='Computer Science') - Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') - Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) - Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) - Student(id=4, name='S4') class TestObjectToDict(unittest.TestCase): + @classmethod + def setUpClass(cls): + db = cls.db = Database() + + class Student(db.Entity): + name = Required(unicode) + scholarship = Optional(int) + gpa = Optional(Decimal, 3, 1) + dob = Optional(date) + group = Optional('Group') + courses = Set('Course') + biography = Optional(LongUnicode) + + class Group(db.Entity): + number = PrimaryKey(int) + students = Set(Student) + + class Course(db.Entity): + name = Required(unicode, unique=True) + students = Set(Student) + + setup_database(db) + + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + c1 = Course(id=1, name='Math') + c2 = Course(id=2, name='Physics') + c3 = Course(id=3, name='Computer Science') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4') + + @classmethod + def tearDownClass(cls): + teardown_database(cls.db) + def setUp(self): rollback() db_session.__enter__() @@ -48,133 +56,172 @@ def tearDown(self): db_session.__exit__() def test1(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict() self.assertEqual(d, dict(id=1, 
name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1)) def test2(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(related_objects=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, - group=Group[1])) + group=self.db.Group[1])) def test3(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(with_collections=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1, courses=[1, 2])) def test4(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(with_collections=True, related_objects=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, - group=Group[1], courses=[Course[1], Course[2]])) + group=self.db.Group[1], courses=[self.db.Course[1], self.db.Course[2]])) def test5(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', scholarship=None, gpa=Decimal('3.1'), dob=None, group=1, biography='some text')) def test6(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only=['id', 'name', 'group']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test7(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(['id', 'name', 'group']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test8(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id, name, group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test9(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id name group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test10(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict('id name group') self.assertEqual(d, dict(id=1, name='S1', group=1)) @raises_exception(AttributeError, 'Entity Student does not have attriute x') def test11(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict('id name x 
group') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test12(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict('id name group', related_objects=True) - self.assertEqual(d, dict(id=1, name='S1', group=Group[1])) + self.assertEqual(d, dict(id=1, name='S1', group=self.db.Group[1])) def test13(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude=['dob', 'gpa', 'scholarship']) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test14(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob, gpa, scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test15(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) @raises_exception(AttributeError, 'Entity Student does not have attriute x') def test16(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa x scholarship') self.assertEqual(d, dict(id=1, name='S1', group=1)) def test17(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship', related_objects=True) - self.assertEqual(d, dict(id=1, name='S1', group=Group[1])) + self.assertEqual(d, dict(id=1, name='S1', group=self.db.Group[1])) def test18(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship', with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', group=1, biography='some text')) def test19(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship biography', with_lazy=True) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test20(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship', with_collections=True) self.assertEqual(d, dict(id=1, name='S1', group=1, courses=[1, 2])) def test21(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(exclude='dob gpa scholarship courses', 
with_collections=True) self.assertEqual(d, dict(id=1, name='S1', group=1)) def test22(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id name group', exclude='dob group') self.assertEqual(d, dict(id=1, name='S1')) def test23(self): - s1 = Student[1] + s1 = self.db.Student[1] d = s1.to_dict(only='id name group', exclude='dob group', with_collections=True, with_lazy=True) self.assertEqual(d, dict(id=1, name='S1')) def test24(self): - c = Course(name='New Course') + c = self.db.Course(id=4, name='New Course') d = c.to_dict() # should do flush and get c.id from the database self.assertEqual(d, dict(id=4, name='New Course')) + class TestSerializationToDict(unittest.TestCase): + @classmethod + def setUpClass(cls): + db = cls.db = Database() + + class Student(db.Entity): + name = Required(unicode) + scholarship = Optional(int) + gpa = Optional(Decimal, 3, 1) + dob = Optional(date) + group = Optional('Group') + courses = Set('Course') + biography = Optional(LongUnicode) + + class Group(db.Entity): + number = PrimaryKey(int) + students = Set(Student) + + class Course(db.Entity): + name = Required(unicode, unique=True) + students = Set(Student) + + setup_database(db) + + with db_session: + g1 = Group(number=1) + g2 = Group(number=2) + c1 = Course(name='Math') + c2 = Course(name='Physics') + c3 = Course(name='Computer Science') + Student(id=1, name='S1', group=g1, gpa=3.1, courses=[c1, c2], biography='some text') + Student(id=2, name='S2', group=g1, gpa=3.2, scholarship=100, dob=date(2000, 1, 1)) + Student(id=3, name='S3', group=g1, gpa=3.3, scholarship=200, dob=date(2001, 1, 2), courses=[c2, c3]) + Student(id=4, name='S4') + + @classmethod + def tearDownClass(cls): + teardown_database(cls.db) + def setUp(self): rollback() db_session.__enter__() @@ -184,7 +231,7 @@ def tearDown(self): db_session.__exit__() def test1(self): - s4 = Student[4] + s4 = self.db.Student[4] self.assertEqual(s4.group, None) d = to_dict(s4) self.assertEqual(d, dict(Student={ 
diff --git a/pony/orm/tests/test_transaction_lock.py b/pony/orm/tests/test_transaction_lock.py index 002557664..07a5ca440 100644 --- a/pony/orm/tests/test_transaction_lock.py +++ b/pony/orm/tests/test_transaction_lock.py @@ -1,27 +1,31 @@ - - import unittest from pony.orm import * +from pony.orm.tests import setup_database, teardown_database db = Database() + class TestPost(db.Entity): category = Optional('TestCategory') name = Optional(str, default='Noname') + class TestCategory(db.Entity): posts = Set(TestPost) -db.bind('sqlite', ':memory:') -db.generate_mapping(create_tables=True) - -with db_session: - post = TestPost() - class TransactionLockTestCase(unittest.TestCase): + @classmethod + def setUpClass(cls): + setup_database(db) + with db_session: + cls.post = TestPost(id=1) + + @classmethod + def tearDownClass(cls): + teardown_database(db) __call__ = db_session(unittest.TestCase.__call__) @@ -29,14 +33,14 @@ def tearDown(self): rollback() def test_create(self): - p = TestPost() + p = TestPost(id=2) p.flush() cache = db._get_cache() self.assertEqual(cache.immediate, True) self.assertEqual(cache.in_transaction, True) def test_update(self): - p = TestPost[post.id] + p = TestPost[self.post.id] p.name = 'Trash' p.flush() cache = db._get_cache() @@ -44,7 +48,7 @@ def test_update(self): self.assertEqual(cache.in_transaction, True) def test_delete(self): - p = TestPost[post.id] + p = TestPost[self.post.id] p.delete() flush() cache = db._get_cache() diff --git a/pony/orm/tests/test_validate.py b/pony/orm/tests/test_validate.py index 020822c1e..813f88a31 100644 --- a/pony/orm/tests/test_validate.py +++ b/pony/orm/tests/test_validate.py @@ -3,30 +3,40 @@ from pony.orm import * from pony.orm import core from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import db_params, teardown_database -db = Database('sqlite', ':memory:') +db = Database() class Person(db.Entity): id = PrimaryKey(int) name = Required(str) tel = Optional(str) 
-db.generate_mapping(check_tables=False) -with db_session: - db.execute(""" - create table Person( - id int primary key, - name text, - tel text - ) - """) +table_name = 'person' class TestValidate(unittest.TestCase): + @classmethod + def setUpClass(cls): + db.bind(**db_params) + db.generate_mapping(check_tables=False) + db.drop_all_tables(with_all_data=True) + with db_session(ddl=True): + db.execute(""" + create table "%s"( + id int primary key, + name text, + tel text + ) + """ % table_name) + + @classmethod + def tearDownClass(cls): + teardown_database(db) @db_session def setUp(self): - db.execute('delete from Person') + db.execute('delete from "%s"' % table_name) registry = getattr(core, '__warningregistry__', {}) for key in list(registry): if type(key) is not tuple: continue @@ -38,7 +48,7 @@ def setUp(self): def test_1a(self): with warnings.catch_warnings(): warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue) - db.insert('Person', id=1, name='', tel='111') + db.insert(table_name, id=1, name='', tel='111') p = Person.get(id=1) self.assertEqual(p.name, '') @@ -48,14 +58,14 @@ def test_1a(self): def test_1b(self): with warnings.catch_warnings(): warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue) - db.insert('Person', id=1, name='', tel='111') + db.insert(table_name, id=1, name='', tel='111') p = Person.get(id=1) @db_session def test_2a(self): with warnings.catch_warnings(): warnings.simplefilter('ignore', DatabaseContainsIncorrectEmptyValue) - db.insert('Person', id=1, name=None, tel='111') + db.insert(table_name, id=1, name=None, tel='111') p = Person.get(id=1) self.assertEqual(p.name, None) @@ -65,7 +75,7 @@ def test_2a(self): def test_2b(self): with warnings.catch_warnings(): warnings.simplefilter('error', DatabaseContainsIncorrectEmptyValue) - db.insert('Person', id=1, name=None, tel='111') + db.insert(table_name, id=1, name=None, tel='111') p = Person.get(id=1) diff --git a/pony/orm/tests/test_virtuals.py 
b/pony/orm/tests/test_virtuals.py new file mode 100644 index 000000000..e69de29bb diff --git a/pony/orm/tests/test_volatile.py b/pony/orm/tests/test_volatile.py index 663091efd..e534ee01f 100644 --- a/pony/orm/tests/test_volatile.py +++ b/pony/orm/tests/test_volatile.py @@ -2,27 +2,32 @@ from pony.orm import * from pony.orm.tests.testutils import * +from pony.orm.tests import setup_database, teardown_database + class TestVolatile(unittest.TestCase): def setUp(self): - db = self.db = Database('sqlite', ':memory:') + db = self.db = Database() class Item(self.db.Entity): name = Required(str) index = Required(int, volatile=True) - db.generate_mapping(create_tables=True) + setup_database(db) with db_session: - Item(name='A', index=1) - Item(name='B', index=2) - Item(name='C', index=3) + Item(id=1, name='A', index=1) + Item(id=2, name='B', index=2) + Item(id=3, name='C', index=3) + + def tearDown(self): + teardown_database(self.db) @db_session def test_1(self): db = self.db Item = db.Item - db.execute('update "Item" set "index" = "index" + 1') + db.execute('update "item" set "index" = "index" + 1') items = Item.select(lambda item: item.index > 0).order_by(Item.id)[:] a, b, c = items self.assertEqual(a.index, 2) From 8abeb10144018bd6e0957a60fb31c49f782117c8 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 22 Jan 2020 20:02:20 +0300 Subject: [PATCH 518/547] Support of `interleave in parent` for CockroachDB --- pony/orm/core.py | 27 +++++++++++++-- pony/orm/dbschema.py | 24 ++++++++++--- pony/orm/tests/test_interleave.py | 57 +++++++++++++++++++++++++++++++ 3 files changed, 101 insertions(+), 7 deletions(-) create mode 100644 pony/orm/tests/test_interleave.py diff --git a/pony/orm/core.py b/pony/orm/core.py index 7acc176a2..b0acd093f 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -1129,7 +1129,8 @@ def get_columns(table, column_names): on_delete = 'SET NULL' else: on_delete = None - table.add_foreign_key(attr.reverse.fk_name, child_columns, 
parent_table, parent_columns, attr.index, on_delete) + table.add_foreign_key(attr.reverse.fk_name, child_columns, parent_table, parent_columns, attr.index, + on_delete, interleave=attr.interleave) elif attr.index and attr.columns: if isinstance(attr.py_type, Array) and provider.dialect != 'PostgreSQL': pass # GIN indexes are supported only in PostgreSQL @@ -2016,7 +2017,7 @@ class Attribute(object): 'lazy', 'lazy_sql_cache', 'args', 'auto', 'default', 'reverse', 'composite_keys', \ 'column', 'columns', 'col_paths', '_columns_checked', 'converters', 'kwargs', \ 'cascade_delete', 'index', 'reverse_index', 'original_default', 'sql_default', 'py_check', 'hidden', \ - 'optimistic', 'fk_name', 'type_has_empty_value' + 'optimistic', 'fk_name', 'type_has_empty_value', 'interleave' def __deepcopy__(attr, memo): return attr # Attribute cannot be cloned by deepcopy() @cut_traceback @@ -2088,6 +2089,7 @@ def __init__(attr, py_type, *args, **kwargs): attr.sql_default = kwargs.pop('sql_default', None) attr.py_check = kwargs.pop('py_check', None) attr.hidden = kwargs.pop('hidden', False) + attr.interleave = kwargs.pop('interleave', None) attr.kwargs = kwargs attr.converters = [] def _init_(attr, entity, name): @@ -2140,6 +2142,12 @@ def _init_(attr, entity, name): elif attr.is_unique: throw(TypeError, 'Unique attribute %s cannot be of type float' % attr) if attr.is_volatile and (attr.is_pk or attr.is_collection): throw(TypeError, '%s attribute %s cannot be volatile' % (attr.__class__.__name__, attr)) + + if attr.interleave is not None: + if attr.is_collection: throw(TypeError, + '`interleave` option cannot be specified for %s attribute %r' % (attr.__class__.__name__, attr)) + if attr.interleave not in (True, False): throw(TypeError, + '`interleave` option value should be True, False or None. 
Got: %r' % attr.interleave) def linked(attr): reverse = attr.reverse if attr.cascade_delete is None: @@ -3750,6 +3758,21 @@ def __init__(entity, name, bases, cls_dict): new_attrs.append(attr) new_attrs.sort(key=attrgetter('id')) + interleave_attrs = [] + for attr in new_attrs: + if attr.interleave is not None: + if attr.interleave: + interleave_attrs.append(attr) + entity._interleave_ = None + if interleave_attrs: + if len(interleave_attrs) > 1: throw(TypeError, + 'only one attribute may be marked as interleave. Got: %s' + % ', '.join(repr(attr) for attr in interleave_attrs)) + interleave = interleave_attrs[0] + if not interleave.is_relation: throw(TypeError, + 'Interleave attribute should be part of relationship. Got: %r' % attr) + entity._interleave_ = interleave + indexes = entity._indexes_ = entity.__dict__.get('_indexes_', []) for attr in new_attrs: if attr.is_unique: indexes.append(Index(attr, is_pk=isinstance(attr, PrimaryKey))) diff --git a/pony/orm/dbschema.py b/pony/orm/dbschema.py index 9b32c081b..9124a52b3 100644 --- a/pony/orm/dbschema.py +++ b/pony/orm/dbschema.py @@ -142,8 +142,18 @@ def get_create_command(table): for foreign_key in sorted(itervalues(table.foreign_keys), key=lambda fk: fk.name): if schema.inline_fk_syntax and len(foreign_key.child_columns) == 1: continue cmd.append(schema.indent+foreign_key.get_sql() + ',') - cmd[-1] = cmd[-1][:-1] - cmd.append(')') + interleave_fks = [ fk for fk in table.foreign_keys.values() if fk.interleave ] + if interleave_fks: + assert len(interleave_fks) == 1 + fk = interleave_fks[0] + cmd.append(schema.indent+fk.get_sql()) + cmd.append(case(') INTERLEAVE IN PARENT %s (%s)') % ( + quote_name(fk.parent_table.name), + ', '.join(quote_name(col.name) for col in fk.child_columns) + )) + else: + cmd[-1] = cmd[-1][:-1] + cmd.append(')') for name, value in sorted(table.options.items()): option = table.format_option(name, value) if option: cmd.append(option) @@ -186,12 +196,14 @@ def add_index(table, index_name, 
columns, is_pk=False, is_unique=None, m2m=False if index and index.name == index_name and index.is_pk == is_pk and index.is_unique == is_unique: return index return table.schema.index_class(index_name, table, columns, is_pk, is_unique) - def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None, on_delete=False): + def add_foreign_key(table, fk_name, child_columns, parent_table, parent_columns, index_name=None, on_delete=False, + interleave=False): if fk_name is None: provider = table.schema.provider child_column_names = tuple(column.name for column in child_columns) fk_name = provider.get_default_fk_name(table.name, parent_table.name, child_column_names) - return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name, on_delete) + return table.schema.fk_class(fk_name, table, child_columns, parent_table, parent_columns, index_name, on_delete, + interleave=interleave) class Column(object): auto_template = '%(type)s PRIMARY KEY AUTOINCREMENT' @@ -329,7 +341,8 @@ def _get_create_sql(index, inside_table): class ForeignKey(Constraint): typename = 'Foreign key' - def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name, on_delete): + def __init__(foreign_key, name, child_table, child_columns, parent_table, parent_columns, index_name, on_delete, + interleave=False): schema = parent_table.schema if schema is not child_table.schema: throw(DBSchemaError, 'Parent and child tables of foreign_key cannot belong to different schemata') @@ -356,6 +369,7 @@ def __init__(foreign_key, name, child_table, child_columns, parent_table, parent foreign_key.child_table = child_table foreign_key.child_columns = child_columns foreign_key.on_delete = on_delete + foreign_key.interleave = interleave if index_name is not False: child_columns_len = len(child_columns) diff --git a/pony/orm/tests/test_interleave.py b/pony/orm/tests/test_interleave.py new file mode 100644 index 
000000000..1662b50ae --- /dev/null +++ b/pony/orm/tests/test_interleave.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import, print_function, division + +import unittest + +from pony.orm.core import * +from pony.orm.tests.testutils import raises_exception +from pony.orm.tests import db_params, only_for + +@only_for(providers=['cockroach']) +class TestDiag(unittest.TestCase): + @raises_exception(TypeError, '`interleave` option cannot be specified for Set attribute Foo.x') + def test_1(self): + db = Database() + class Foo(db.Entity): + x = Set('Bar', interleave=True) + class Bar(db.Entity): + y = Required('Foo') + + @raises_exception(TypeError, "`interleave` option value should be True, False or None. Got: 'yes'") + def test_2(self): + db = Database() + class Foo(db.Entity): + x = Required('Bar', interleave='yes') + class Bar(db.Entity): + y = Set('Foo') + + @raises_exception(TypeError, 'only one attribute may be marked as interleave. Got: Foo.x, Foo.y') + def test_3(self): + db = Database() + class Foo(db.Entity): + x = Required(int, interleave=True) + y = Required(int, interleave=True) + + @raises_exception(TypeError, 'Interleave attribute should be part of relationship. 
Got: Foo.x') + def test_4(self): + db = Database() + class Foo(db.Entity): + x = Required(int, interleave=True) + + def test_5(self): + db = Database(**db_params) + class Bar(db.Entity): + y = Set('Foo') + + class Foo(db.Entity): + x = Required('Bar', interleave=True) + id = Required(int) + PrimaryKey(x, id) + + db.generate_mapping(create_tables=True) + s = ') INTERLEAVE IN PARENT "bar" ("x")' + self.assertIn(s, db.schema.tables['foo'].get_create_command()) + db.drop_all_tables() + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 91935c057083d174dd2bd250cb8a5beecb850275 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 19 Dec 2019 18:39:01 +0300 Subject: [PATCH 519/547] fix date difference in PostgreSQL --- pony/orm/dbproviders/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index 9477d8fe1..de9d5fda6 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -76,7 +76,7 @@ def DATE_ADD(builder, expr, delta): def DATE_SUB(builder, expr, delta): return '(', builder(expr), ' - ', builder(delta), ')' def DATE_DIFF(builder, expr1, expr2): - return builder(expr1), ' - ', builder(expr2) + return '((', builder(expr1), ' - ', builder(expr2), ") * interval '1 day')" def DATETIME_ADD(builder, expr, delta): return '(', builder(expr), ' + ', builder(delta), ')' def DATETIME_SUB(builder, expr, delta): From f27aeeb7a365473398f66da296029cde701500fe Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 19 Dec 2019 18:46:58 +0300 Subject: [PATCH 520/547] min/max for PostgreSQL fixed --- pony/orm/sqltranslation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index b2fdf6308..841c8967e 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2885,7 +2885,7 @@ def minmax(monad, sqlop, *args): args = list(args) for 
i, arg in enumerate(args): if arg.type is bool: - args[i] = NumericExprMonad(int, [ 'TO_INT', arg.getsql() ], nullable=arg.nullable) + args[i] = NumericExprMonad(int, [ 'TO_INT', arg.getsql()[0] ], nullable=arg.nullable) sql = [ sqlop, None ] + [ arg.getsql()[0] for arg in args ] return ExprMonad.new(t, sql, nullable=any(arg.nullable for arg in args)) From 8cf3667c0262c25344b0f06fec75c9b9f5d80cec Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 19 Dec 2019 20:36:12 +0300 Subject: [PATCH 521/547] fix casting json to dobule in PostgreSQL --- pony/orm/dbproviders/postgres.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/postgres.py b/pony/orm/dbproviders/postgres.py index de9d5fda6..81f6d7158 100644 --- a/pony/orm/dbproviders/postgres.py +++ b/pony/orm/dbproviders/postgres.py @@ -95,7 +95,7 @@ def eval_json_path(builder, values): def JSON_QUERY(builder, expr, path): path_sql, has_params, has_wildcards = builder.build_json_path(path) return '(', builder(expr), " #> ", path_sql, ')' - json_value_type_mapping = {bool: 'boolean', int: 'int', float: 'real'} + json_value_type_mapping = {bool: 'boolean', int: 'int', float: 'double precision'} def JSON_VALUE(builder, expr, path, type): if type is ormtypes.Json: return builder.JSON_QUERY(expr, path) path_sql, has_params, has_wildcards = builder.build_json_path(path) From 2deadda6c2282789b2f545c7d03b9c70f77eaa9f Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Thu, 19 Dec 2019 20:41:57 +0300 Subject: [PATCH 522/547] Fix count by several columns in PostgreSQL --- pony/orm/sqlbuilding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index f747b94d4..07f78db39 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -472,7 +472,10 @@ def COUNT(builder, distinct, *expr_list): assert distinct in (None, True, False) if not distinct: if not expr_list: return ['COUNT(*)'] - return 
'COUNT(', join(', ', imap(builder, expr_list)), ')' + if builder.dialect == 'PostgreSQL': + return 'COUNT(', builder.ROW(*expr_list), ')' + else: + return 'COUNT(', join(', ', imap(builder, expr_list)), ')' if not expr_list: throw(AstError, 'COUNT(DISTINCT) without argument') if len(expr_list) == 1: return 'COUNT(DISTINCT ', builder(expr_list[0]), ')' From 861874b522841c8ecd92452df927d3921bafcacd Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 15 Jan 2020 14:36:39 +0300 Subject: [PATCH 523/547] PostgreSQL distinct bug fixed --- pony/orm/sqltranslation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 841c8967e..e99a9c557 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -657,7 +657,9 @@ def construct_sql_ast(translator, limit=None, offset=None, distinct=None, aggr_func_name=None, aggr_func_distinct=None, sep=None, for_update=False, nowait=False, skip_locked=False, is_not_null_checks=False): attr_offsets = None - if distinct is None: distinct = translator.distinct + if distinct is None: + if not translator.order: + distinct = translator.distinct ast_transformer = lambda ast: ast if for_update: sql_ast = [ 'SELECT_FOR_UPDATE', nowait, skip_locked ] From d8168d0171e66b25e73d9a374ae7e2d5350f0bba Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 4 Dec 2019 17:13:27 +0300 Subject: [PATCH 524/547] Array negative indexes fixes --- pony/orm/sqltranslation.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index e99a9c557..4bce4d565 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -2062,12 +2062,10 @@ def _index(monad, index, from_one, plus_one): expr_sql = monad.getsql()[0] index_sql = index.getsql()[0] value = index_sql[1] - if from_one and plus_one: - if value >= 0: - index_sql = ['VALUE', value + 1] - else: - index_sql = 
['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value) + 1]] - + if value >= 0: + index_sql = ['VALUE', value + int(from_one and plus_one)] + else: + index_sql = ['SUB', ['ARRAY_LENGTH', expr_sql], ['VALUE', abs(value + int(from_one and plus_one))]] return index_sql elif isinstance(index, NumericMixin): expr_sql = monad.getsql()[0] From 8f7996f992aa333f5023816ce92349bd6b00ee4b Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 4 Feb 2020 16:46:05 +0300 Subject: [PATCH 525/547] Micro refactoring/optimization --- pony/orm/sqlbuilding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 07f78db39..11ad4abc3 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -110,9 +110,9 @@ def flat(tree): x = stack_pop() if isinstance(x, basestring): result_append(x) else: - try: stack_extend(reversed(x)) + try: stack_extend(x) except TypeError: result_append(x) - return result + return result[::-1] def flat_conditions(conditions): result = [] From c6ccc99c47cc6f431187aa4721e4357f175658c6 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Fri, 6 Dec 2019 17:07:42 +0300 Subject: [PATCH 526/547] getitem support for strings including slices and negative indexes --- pony/orm/core.py | 6 +- pony/orm/dbproviders/sqlite.py | 17 ++ pony/orm/sqlbuilding.py | 89 ++++++++++ pony/orm/sqltranslation.py | 90 +++++----- pony/orm/tests/test_declarative_strings.py | 182 ++++++++++++++++----- 5 files changed, 297 insertions(+), 87 deletions(-) diff --git a/pony/orm/core.py b/pony/orm/core.py index b0acd093f..fc54cb498 100644 --- a/pony/orm/core.py +++ b/pony/orm/core.py @@ -5768,9 +5768,9 @@ def _get_translator(query, query_key, vars): all_func_vartypes.update(func_vartypes) if all_func_vartypes != translator.func_vartypes: return None, vars.copy() - for key, attrname in iteritems(translator.getattr_values): + for key, val in iteritems(translator.fixed_param_values): assert key in new_vars 
- if attrname != new_vars[key]: + if val != new_vars[key]: del database._translator_cache[query_key] return None, vars.copy() return translator, new_vars @@ -5785,7 +5785,7 @@ def _construct_sql_and_arguments(query, limit=None, offset=None, range=None, agg sql_key = HashableDict( query._key, vartypes=HashableDict(query._translator.vartypes), - getattr_values=HashableDict(translator.getattr_values), + fixed_param_values=HashableDict(translator.fixed_param_values), limit=limit, offset=offset, distinct=query._distinct, diff --git a/pony/orm/dbproviders/sqlite.py b/pony/orm/dbproviders/sqlite.py index 9f23c60f1..ad6ae38f9 100644 --- a/pony/orm/dbproviders/sqlite.py +++ b/pony/orm/dbproviders/sqlite.py @@ -81,6 +81,12 @@ def SELECT_FOR_UPDATE(builder, nowait, skip_locked, *sections): def INSERT(builder, table_name, columns, values, returning=None): if not values: return 'INSERT INTO %s DEFAULT VALUES' % builder.quote_name(table_name) return SQLBuilder.INSERT(builder, table_name, columns, values, returning) + def STRING_SLICE(builder, expr, start, stop): + if start is None: + start = [ 'VALUE', None ] + if stop is None: + stop = [ 'VALUE', None ] + return "py_string_slice(", builder(expr), ', ', builder(start), ', ', builder(stop), ")" def TODAY(builder): return "date('now', 'localtime')" def NOW(builder): @@ -616,6 +622,15 @@ def py_array_slice(array, start, stop): def py_make_array(*items): return dumps(items) +def py_string_slice(s, start, end): + if s is None: + return None + if isinstance(start, basestring): + start = int(start) + if isinstance(end, basestring): + end = int(end) + return s[start:end] + class SQLitePool(Pool): def __init__(pool, filename, create_db, **kwargs): # called separately in each thread pool.filename = filename @@ -650,6 +665,8 @@ def create_function(name, num_params, func): create_function('py_array_slice', 3, py_array_slice) create_function('py_make_array', -1, py_make_array) + create_function('py_string_slice', 3, py_string_slice) + if 
sqlite.sqlite_version_info >= (3, 6, 19): con.execute('PRAGMA foreign_keys = true') diff --git a/pony/orm/sqlbuilding.py b/pony/orm/sqlbuilding.py index 11ad4abc3..671ba5e37 100644 --- a/pony/orm/sqlbuilding.py +++ b/pony/orm/sqlbuilding.py @@ -523,6 +523,93 @@ def MAX(builder, distinct, *args): def SUBSTR(builder, expr, start, len=None): if len is None: return 'substr(', builder(expr), ', ', builder(start), ')' return 'substr(', builder(expr), ', ', builder(start), ', ', builder(len), ')' + def STRING_SLICE(builder, expr, start, stop): + if start is None: + start = [ 'VALUE', 0 ] + + if start[0] == 'VALUE': + start_value = start[1] + if builder.dialect == 'PostgreSQL' and start_value < 0: + index_sql = [ 'LENGTH', expr ] + if start_value < -1: + index_sql = [ 'SUB', index_sql, [ 'VALUE', -(start_value + 1) ] ] + else: + if start_value >= 0: start_value += 1 + index_sql = [ 'VALUE', start_value ] + else: + inner_sql = start + then = [ 'ADD', inner_sql, [ 'VALUE', 1 ] ] + else_ = [ 'ADD', [ 'LENGTH', expr ], then ] if builder.dialect == 'PostgreSQL' else inner_sql + index_sql = [ 'IF', [ 'GE', inner_sql, [ 'VALUE', 0 ] ], then, else_ ] + + if stop is None: + len_sql = None + elif stop[0] == 'VALUE': + stop_value = stop[1] + if start[0] == 'VALUE': + start_value = start[1] + if start_value >= 0 and stop_value >= 0: + len_sql = [ 'VALUE', stop_value - start_value ] + elif start_value < 0 and stop_value < 0: + len_sql = [ 'VALUE', stop_value - start_value ] + elif start_value >= 0 and stop_value < 0: + len_sql = [ 'SUB', [ 'LENGTH', expr ], [ 'VALUE', start_value - stop_value ]] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + elif start_value < 0 and stop_value >= 0: + len_sql = [ 'SUB', [ 'VALUE', stop_value + 1 ], index_sql ] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + else: + assert False # pragma: nocover1 + else: + start_sql = [ 'COALESCE', start, [ 'VALUE', 0 ] ] + if stop_value >= 0: + start_positive = [ 'SUB', stop, start_sql ] + 
start_negative = [ 'SUB', [ 'VALUE', stop_value + 1 ], index_sql ] + else: + start_positive = [ 'SUB', [ 'LENGTH', expr ], [ 'ADD', start_sql, [ 'VALUE', -stop_value ] ] ] + start_negative = [ 'SUB', stop, start_sql] + len_sql = [ 'IF', [ 'GE', start_sql, [ 'VALUE', 0 ] ], start_positive, start_negative ] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + else: + stop_sql = [ 'COALESCE', stop, [ 'VALUE', -1 ] ] + if start[0] == 'VALUE': + start_value = start[1] + start_sql = [ 'VALUE', start_value ] + if start_value >= 0: + stop_positive = [ 'SUB', stop_sql, start_sql ] + stop_negative = [ 'SUB', [ 'LENGTH', expr ], [ 'SUB', start_sql, stop_sql ] ] + else: + stop_positive = [ 'SUB', [ 'ADD', stop_sql, [ 'VALUE', 1 ] ], index_sql ] + stop_negative = [ 'SUB', stop_sql, start_sql] + len_sql = [ 'IF', [ 'GE', stop_sql, [ 'VALUE', 0 ] ], stop_positive, stop_negative ] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + else: + start_sql = [ 'COALESCE', start, [ 'VALUE', 0 ] ] + both_positive = [ 'SUB', stop_sql, start_sql ] + both_negative = both_positive + start_positive = [ 'SUB', [ 'LENGTH', expr ], [ 'SUB', start_sql, stop_sql ] ] + stop_positive = [ 'SUB', [ 'ADD', stop_sql, [ 'VALUE', 1 ] ], index_sql ] + len_sql = [ 'CASE', None, [ + ( + [ 'AND', [ 'GE', start_sql, [ 'VALUE', 0 ] ], [ 'GE', stop_sql, [ 'VALUE', 0 ] ] ], + both_positive + ), + ( + [ 'AND', [ 'LT', start_sql, [ 'VALUE', 0 ] ], [ 'LT', stop_sql, [ 'VALUE', 0 ] ] ], + both_negative + ), + ( + [ 'AND', [ 'GE', start_sql, [ 'VALUE', 0 ] ], [ 'LT', stop_sql, [ 'VALUE', 0 ] ] ], + start_positive + ), + ( + [ 'AND', [ 'LT', start_sql, [ 'VALUE', 0 ] ], [ 'GE', stop_sql, [ 'VALUE', 0 ] ] ], + stop_positive + ), + ]] + len_sql = [ 'MAX', False, len_sql, [ 'VALUE', 0 ] ] + sql = [ 'SUBSTR', expr, index_sql, len_sql ] + return builder(sql) def CASE(builder, expr, cases, default=None): if expr is None and default is not None and default[0] == 'CASE' and default[1] is None: cases2, default2 = 
default[2:] @@ -537,6 +624,8 @@ def CASE(builder, expr, cases, default=None): result.extend((' else ', builder(default))) result.append(' end') return result + def IF(builder, cond, then, else_): + return builder.CASE(None, [(cond, then)], else_) def TRIM(builder, expr, chars=None): if chars is None: return 'trim(', builder(expr), ')' return 'trim(', builder(expr), ', ', builder(chars), ')' diff --git a/pony/orm/sqltranslation.py b/pony/orm/sqltranslation.py index 4bce4d565..4ad1a2af7 100644 --- a/pony/orm/sqltranslation.py +++ b/pony/orm/sqltranslation.py @@ -257,7 +257,7 @@ def init(translator, tree, parent_translator, code_key=None, filter_num=None, ex translator.vartypes = vartypes translator.namespace_stack = [{}] if not parent_translator else [ parent_translator.namespace.copy() ] translator.func_extractors_map = {} - translator.getattr_values = {} + translator.fixed_param_values = {} translator.func_vartypes = {} translator.left_join = left_join translator.optimize = optimize @@ -1841,17 +1841,37 @@ def mixin_init(monad): assert issubclass(monad.type, basestring), monad.type __add__ = make_string_binop('+', 'CONCAT') def __getitem__(monad, index): + root_translator = monad.translator.root_translator + dialect = root_translator.database.provider.dialect + + def param_to_const(monad, is_start=True): + if isinstance(monad, ParamMonad): + key = monad.paramkey[0] + if key in root_translator.fixed_param_values: + index_value = root_translator.fixed_param_values[key] + else: + index_value = root_translator.vars[key] + if index_value is None: + index_value = 0 if is_start else -1 + root_translator.fixed_param_values[key] = index_value + return ConstMonad.new(index_value) + return monad + if isinstance(index, ListMonad): throw(TypeError, "String index must be of 'int' type. 
Got 'tuple' in {EXPR}") elif isinstance(index, slice): if index.step is not None: throw(TypeError, 'Step is not supported in {EXPR}') start, stop = index.start, index.stop - if start is None and stop is None: return monad - if isinstance(monad, StringConstMonad) \ - and (start is None or isinstance(start, NumericConstMonad)) \ - and (stop is None or isinstance(stop, NumericConstMonad)): - if start is not None: start = start.value - if stop is not None: stop = stop.value - return ConstMonad.new(monad.value[start:stop]) + start = param_to_const(start, is_start=True) + stop = param_to_const(stop, is_start=False) + start_value = stop_value = None + if start is None: start_value = 0 + if stop_value is None: stop_value = -1 + if isinstance(start, ConstMonad): start_value = start.value + if isinstance(stop, ConstMonad): stop_value = stop.value + if start_value == 0 and stop_value == -1: + return monad + if isinstance(monad, StringConstMonad) and start_value is not None and stop_value is not None: + return ConstMonad.new(monad.value[start_value:stop_value]) if start is not None and start.type is not int: throw(TypeError, "Invalid type of start index (expected 'int', got %r) in string slice {EXPR}" % type2str(start.type)) @@ -1859,46 +1879,34 @@ def __getitem__(monad, index): throw(TypeError, "Invalid type of stop index (expected 'int', got %r) in string slice {EXPR}" % type2str(stop.type)) expr_sql = monad.getsql()[0] - if start is None: start = ConstMonad.new(0) - - if isinstance(start, NumericConstMonad): - if start.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') - start_sql = [ 'VALUE', start.value + 1 ] - else: - start_sql = start.getsql()[0] - start_sql = [ 'ADD', start_sql, [ 'VALUE', 1 ] ] - - if stop is None: - len_sql = None - elif isinstance(stop, NumericConstMonad): - if stop.value < 0: throw(NotImplementedError, 'Negative indices are not supported in string slice {EXPR}') - if isinstance(start, 
NumericConstMonad): - len_sql = [ 'VALUE', stop.value - start.value ] - else: - len_sql = [ 'SUB', [ 'VALUE', stop.value ], start.getsql()[0] ] - else: - stop_sql = stop.getsql()[0] - if isinstance(start, NumericConstMonad): - len_sql = [ 'SUB', stop_sql, [ 'VALUE', start.value ] ] - else: - len_sql = [ 'SUB', stop_sql, start.getsql()[0] ] - - sql = [ 'SUBSTR', expr_sql, start_sql, len_sql ] - return StringExprMonad(monad.type, sql, - nullable=monad.nullable or start.nullable or stop is not None and stop.nullable) + start_sql = None if start is None else start.getsql()[0] + stop_sql = None if stop is None else stop.getsql()[0] + sql = [ 'STRING_SLICE', expr_sql, start_sql, stop_sql ] + return StringExprMonad(monad.type, sql, nullable= + monad.nullable or start is not None and start.nullable or stop is not None and stop.nullable) + index = param_to_const(index) if isinstance(monad, StringConstMonad) and isinstance(index, NumericConstMonad): return ConstMonad.new(monad.value[index.value]) if index.type is not int: throw(TypeError, 'String indices must be integers. 
Got %r in expression {EXPR}' % type2str(index.type)) expr_sql = monad.getsql()[0] + if isinstance(index, NumericConstMonad): value = index.value - if value >= 0: value += 1 - index_sql = [ 'VALUE', value ] + if dialect == 'PostgreSQL' and value < 0: + index_sql = [ 'LENGTH', expr_sql ] + if value < -1: + index_sql = [ 'SUB', index_sql, [ 'VALUE', -(value + 1) ] ] + else: + if value >= 0: value += 1 + index_sql = [ 'VALUE', value ] else: inner_sql = index.getsql()[0] - index_sql = [ 'ADD', inner_sql, [ 'CASE', None, [ (['GE', inner_sql, [ 'VALUE', 0 ]], [ 'VALUE', 1 ]) ], [ 'VALUE', 0 ] ] ] + then = ['ADD', inner_sql, ['VALUE', 1]] + else_ = [ 'ADD', ['LENGTH', expr_sql], then ] if dialect == 'PostgreSQL' else inner_sql + index_sql = [ 'IF', [ 'GE', inner_sql, [ 'VALUE', 0 ] ], then, else_ ] + sql = [ 'SUBSTR', expr_sql, index_sql, [ 'VALUE', 1 ] ] return StringExprMonad(monad.type, sql, nullable=monad.nullable) def negate(monad): @@ -2776,11 +2784,11 @@ def call(monad, obj_monad, name_monad): elif isinstance(name_monad, ParamMonad): translator = monad.translator.root_translator key = name_monad.paramkey[0] - if key in translator.getattr_values: - attrname = translator.getattr_values[key] + if key in translator.fixed_param_values: + attrname = translator.fixed_param_values[key] else: attrname = translator.vars[key] - translator.getattr_values[key] = attrname + translator.fixed_param_values[key] = attrname else: throw(TranslationError, 'Expression `{EXPR}` cannot be translated into SQL ' 'because %s will be different for each row' % ast2src(name_monad.node)) if not isinstance(attrname, basestring): diff --git a/pony/orm/tests/test_declarative_strings.py b/pony/orm/tests/test_declarative_strings.py index f82b6ff18..3c6210bd1 100644 --- a/pony/orm/tests/test_declarative_strings.py +++ b/pony/orm/tests/test_declarative_strings.py @@ -4,7 +4,7 @@ from pony.orm.core import * from pony.orm.tests.testutils import * -from pony.orm.tests import setup_database, 
teardown_database +from pony.orm.tests import setup_database, teardown_database, only_for db = Database() @@ -37,103 +37,199 @@ def tearDown(self): rollback() db_session.__exit__() - def test_getitem_01(self): - result = set(select(s for s in Student if s.name[:] == 'Ann')) - self.assertEqual(result, {Student[1]}) - def test_getitem_1(self): result = set(select(s for s in Student if s.name[1] == 'o')) self.assertEqual(result, {Student[2], Student[4]}) - - def test_getitem_2(self): x = 1 result = set(select(s for s in Student if s.name[x] == 'o')) self.assertEqual(result, {Student[2], Student[4]}) - def test_getitem_3(self): + def test_getitem_2(self): result = set(select(s for s in Student if s.name[-1] == 'n')) self.assertEqual(result, {Student[1], Student[4]}) - - def test_getitem_4(self): x = -1 result = set(select(s for s in Student if s.name[x] == 'n')) self.assertEqual(result, {Student[1], Student[4]}) - def test_getitem_5(self): + def test_getitem_3(self): result = set(select(s for s in Student if s.name[-2] == 't')) self.assertEqual(result, {Student[3], Student[5]}) - - @sql_debugging - def test_getitem_6(self): x = -2 - select((s.name, s.name[x]) for s in Student).show() result = set(select(s for s in Student if s.name[x] == 't')) self.assertEqual(result, {Student[3], Student[5]}) + def test_getitem_4(self): + result = set(select(s for s in Student if s.name[-s.id] == 'n')) + self.assertEqual(result, {Student[1]}) + def test_slice_1(self): - result = set(select(s for s in Student if s.name[0:3] == "Jon")) - self.assertEqual(result, {Student[4]}) + result = set(select(s for s in Student if s.name[:] == "Ann")) + self.assertEqual(result, {Student[1]}) + result = set(select(s for s in Student if s.name[0:] == "Ann")) + self.assertEqual(result, {Student[1]}) def test_slice_2(self): result = set(select(s for s in Student if s.name[:3] == "Jon")) self.assertEqual(result, {Student[4]}) - - def test_slice_3(self): - x = 3 - result = set(select(s for s in Student 
if s.name[:x] == "Jon")) + result = set(select(s for s in Student if s.name[0:3] == "Jon")) self.assertEqual(result, {Student[4]}) - - def test_slice_4(self): - x = 3 - result = set(select(s for s in Student if s.name[0:x] == "Jon")) + x = 0 + y = 3 + result = set(select(s for s in Student if s.name[:y] == "Jon")) + self.assertEqual(result, {Student[4]}) + result = set(select(s for s in Student if s.name[x:y] == "Jon")) + self.assertEqual(result, {Student[4]}) + result = set(select(s for s in Student if s.name[x:3] == "Jon")) self.assertEqual(result, {Student[4]}) - def test_slice_5(self): + def test_slice_3(self): result = set(select(s for s in Student if s.name[0:10] == "Ann")) self.assertEqual(result, {Student[1]}) - - def test_slice_6(self): - result = set(select(s for s in Student if s.name[0:] == "Ann")) + x = 10 + result = set(select(s for s in Student if s.name[0:x] == "Ann")) self.assertEqual(result, {Student[1]}) - - def test_slice_7(self): - result = set(select(s for s in Student if s.name[:] == "Ann")) + result = set(select(s for s in Student if s.name[:x] == "Ann")) self.assertEqual(result, {Student[1]}) def test_slice_8(self): result = set(select(s for s in Student if s.name[1:] == "nn")) self.assertEqual(result, {Student[1]}) - - def test_slice_9(self): x = 1 result = set(select(s for s in Student if s.name[x:] == "nn")) self.assertEqual(result, {Student[1]}) def test_slice_10(self): - x = 0 - result = set(select(s for s in Student if s.name[x:3] == "Ann")) - self.assertEqual(result, {Student[1]}) - - def test_slice_11(self): result = set(select(s for s in Student if s.name[1:3] == "et")) self.assertEqual(result, {Student[3], Student[5]}) - - def test_slice_12(self): x = 1 y = 3 result = set(select(s for s in Student if s.name[x:y] == "et")) self.assertEqual(result, {Student[3], Student[5]}) - def test_slice_13(self): + def test_slice_11(self): x = 10 y = 20 result = set(select(s for s in Student if s.name[x:y] == '')) self.assertEqual(result, 
{Student[1], Student[2], Student[3], Student[4], Student[5]}) - def test_slice_14(self): + def test_slice_12(self): result = set(select(s for s in Student if s.name[-2:] == "nn")) self.assertEqual(result, {Student[1]}) + def test_slice_13(self): + result = set(select(s for s in Student if s.name[:-1] == "Ann")) + self.assertEqual(result, {Student[1]}) + result = set(select(s for s in Student if s.name[0:-1] == "Ann")) + self.assertEqual(result, {Student[1]}) + x = 0 + y = -1 + result = set(select(s for s in Student if s.name[x:y] == "Ann")) + self.assertEqual(result, {Student[1]}) + + def test_slice_14(self): + result = set(select(s for s in Student if s.name[-4:-2] == "th")) + self.assertEqual(result, {Student[4]}) + x = -4 + y = -2 + result = set(select(s for s in Student if s.name[x:y] == "th")) + self.assertEqual(result, {Student[4]}) + + def test_slice_15(self): + result = set(select(s for s in Student if s.name[4:-2] == "th")) + self.assertEqual(result, {Student[4]}) + x = 4 + y = -2 + result = set(select(s for s in Student if s.name[x:y] == "th")) + self.assertEqual(result, {Student[4]}) + + def test_slice_16(self): + result = list(select(s.name[-2:3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'nn', 'ob', 't', 't']) + x = -2 + y = 3 + result = list(select(s.name[x:y] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'nn', 'ob', 't', 't']) + + def test_slice_17(self): + result = list(select(s.name[s.id:5] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 't']) + x = 5 + result = list(select(s.name[s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 't']) + + def test_slice_18(self): + result = list(select(s.name[-s.id:5] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 'ob', 't']) + x = 5 + result = list(select(s.name[-s.id:x] for s in 
Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 'ob', 't']) + + def test_slice_19a(self): + result = list(select(s.name[s.id:] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 'than']) + + def test_slice_19b(self): + result = list(select(s.name[s.id:-1] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', 'n', 'tha']) + x = -1 + result = list(select(s.name[s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', 'n', 'tha']) + + def test_slice_19c(self): + result = list(select(s.name[s.id:-2] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', '', 'th']) + x = -2 + result = list(select(s.name[s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', '', 'th']) + + def test_slice_20a(self): + result = list(select(s.name[-s.id:] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 'ob', 'than']) + + def test_slice_20b(self): + result = list(select(s.name[-s.id:-1] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'Pet', 'et', 'o', 'tha']) + x = -1 + result = list(select(s.name[-s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'Pet', 'et', 'o', 'tha']) + + def test_slice_20c(self): + result = list(select(s.name[-s.id:-2] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', 'Pe', 'e', 'th']) + x = -2 + result = list(select(s.name[-s.id:x] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', 'Pe', 'e', 'th']) + + def test_slice_21(self): + result = list(select(s.name[1:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'et', 'ete', 'o', 'ona']) + x = 1 + result = list(select(s.name[x:s.id] for s in Student).without_distinct()) + 
self.assertEqual(sorted(result), ['', 'et', 'ete', 'o', 'ona']) + + def test_slice_22(self): + result = list(select(s.name[-3:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'A', 'Bo', 'et', 'ete']) + x = -3 + result = list(select(s.name[x:s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'A', 'Bo', 'et', 'ete']) + + def test_slice_23(self): + result = list(select(s.name[s.id:s.id+3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'b', 'h', 'nn', 'tha']) + + def test_slice_24(self): + result = list(select(s.name[-s.id*2:-s.id] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', 'B', 'B', 'Jona', 'n']) + + def test_slice_25(self): + result = list(select(s.name[s.id:-s.id+3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['', '', '', 'n', 'tha']) + + def test_slice_26(self): + result = list(select(s.name[-s.id:s.id+3] for s in Student).without_distinct()) + self.assertEqual(sorted(result), ['Pete', 'eth', 'n', 'ob', 'tha']) + def test_nonzero(self): result = set(select(s for s in Student if s.foo)) self.assertEqual(result, {Student[1], Student[2], Student[3]}) From bd88abe80d82df2eff55e1e5f5430aff3416c6a2 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Tue, 4 Feb 2020 18:32:31 +0300 Subject: [PATCH 527/547] Added TeamCity tests builds --- README.md | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/README.md b/README.md index 6ebf3a190..216130372 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,28 @@ +#### PostgreSQL +Python 2 + + +Python 3 + + + +#### SQLite +Python 2 + + +Python 3 + + + +#### CockroachDB +Python 2 + + +Python 3 + + + + Pony Object-Relational Mapper ============================= From 836224ae79019bf0a3691b34004e877895901b68 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 4 Feb 2020 20:08:19 +0300 Subject: [PATCH 528/547] Update CHANGELOG.md and 
pony.__version__: 0.7.12-dev -> 0.7.12 --- CHANGELOG.md | 22 ++++++++++++++++++++++ pony/__init__.py | 2 +- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 74912fc20..4ecd45a4b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,25 @@ +# PonyORM release 0.7.12 (2020-02-04) + +## Features + +* CockroachDB support added +* CI testing for SQLite, PostgreSQL & CockroachDB + +## Bugfixes + +* Fix translation of getting array items with negative indexes +* Fix string getitem translation for slices and negative indexes +* PostgreSQL DISTINCT bug fixed for queries with ORDER BY clause +* Fix date difference syntax in PostgreSQL +* Fix casting json to dobule in PostgreSQL +* Fix count by several columns in PostgreSQL +* Fix PostgreSQL MIN and MAX expressions on boolean columns +* Fix determination of interactive mode in PyCharm +* Fix column definition when `sql_default` is specified: DEFAULT should be before NOT NULL +* Relax checks on updating in-memory cache indexes (don't throw CacheIndexError on valid cases) +* Fix deduplication logic for attribute values + + # PonyORM release 0.7.11 (2019-10-23) ## Features diff --git a/pony/__init__.py b/pony/__init__.py index 3cd2176ef..e4fb02b81 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.12-dev' +__version__ = '0.7.12' uid = str(random.randint(1, 1000000)) From 76c8eef6ddd09e9a3c2ad7dbb606ad26ac648a80 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 26 Feb 2020 09:04:26 +0300 Subject: [PATCH 529/547] Update Pony version: 0.7.12 -> 0.7.13-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index e4fb02b81..b204d9a7d 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.12' +__version__ = '0.7.13-dev' uid = 
str(random.randint(1, 1000000)) From 371cb3004aea52f3e7512a7397a71e2403d15405 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 26 Feb 2020 12:59:06 +0300 Subject: [PATCH 530/547] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 216130372..c0f3e9eef 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ +# Downloads +[![Downloads](https://pepy.tech/badge/pony)](https://pepy.tech/project/pony) [![Downloads](https://pepy.tech/badge/pony/month)](https://pepy.tech/project/pony/month) [![Downloads](https://pepy.tech/badge/pony/week)](https://pepy.tech/project/pony/week) + +# Tests + #### PostgreSQL Python 2 From dbdae7b879f2712e0185d620736e5d8fe2c7c1a2 Mon Sep 17 00:00:00 2001 From: Alexander Tischenko Date: Wed, 26 Feb 2020 12:59:06 +0300 Subject: [PATCH 531/547] Update README.md --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 216130372..c0f3e9eef 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ +# Downloads +[![Downloads](https://pepy.tech/badge/pony)](https://pepy.tech/project/pony) [![Downloads](https://pepy.tech/badge/pony/month)](https://pepy.tech/project/pony/month) [![Downloads](https://pepy.tech/badge/pony/week)](https://pepy.tech/project/pony/week) + +# Tests + #### PostgreSQL Python 2 From b8e2ed0c8ef564649404539f5624460d76e2e693 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Tue, 3 Mar 2020 15:34:58 +0300 Subject: [PATCH 532/547] Update CHANGELOG.md and pony.__version__: 0.7.13-dev -> 0.7.13 --- CHANGELOG.md | 5 +++++ pony/__init__.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ecd45a4b..6e5d64213 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,8 @@ +# PonyORM release 0.7.13 (2020-03-03) + +This release does not contains new features or bugfixes. 
Its goal is to test automatic release building and uploading + + # PonyORM release 0.7.12 (2020-02-04) ## Features diff --git a/pony/__init__.py b/pony/__init__.py index b204d9a7d..53b3a770b 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.13-dev' +__version__ = '0.7.13' uid = str(random.randint(1, 1000000)) From 5ba67c84853e0c7d0b88c3d89ac02e6fe3aa93d3 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 4 Mar 2020 00:22:15 +0300 Subject: [PATCH 533/547] Fix wording --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e5d64213..f6e083879 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,6 @@ # PonyORM release 0.7.13 (2020-03-03) -This release does not contains new features or bugfixes. Its goal is to test automatic release building and uploading - +This release contains no new features or bugfixes. The only reason for this release is to test our CI/CD process. 
# PonyORM release 0.7.12 (2020-02-04) From d702c5c6db8013f97607df9021e6994157d5ee18 Mon Sep 17 00:00:00 2001 From: Alexander Kozlovsky Date: Wed, 4 Mar 2020 00:23:23 +0300 Subject: [PATCH 534/547] Update version: 0.7.13 -> 0.7.14-dev --- pony/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/__init__.py b/pony/__init__.py index 53b3a770b..f39a9b4d9 100644 --- a/pony/__init__.py +++ b/pony/__init__.py @@ -4,7 +4,7 @@ from os.path import dirname from itertools import count -__version__ = '0.7.13' +__version__ = '0.7.14-dev' uid = str(random.randint(1, 1000000)) From 5193bb1f07e42349dadb01bd72b22e62baca4447 Mon Sep 17 00:00:00 2001 From: Alexey Malashkevich Date: Mon, 2 Mar 2020 13:01:27 -0500 Subject: [PATCH 535/547] Update BACKERS.md --- BACKERS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/BACKERS.md b/BACKERS.md index cb5e8cb4f..1dc5db04c 100644 --- a/BACKERS.md +++ b/BACKERS.md @@ -15,3 +15,4 @@ Pony ORM is Apache 2.0 licensed open source project. If you would like to suppor - Johnathan Nader - Andrei Rachalouski - Juan Pablo Scaletti +- Marcus Birkenkrahe From 3769fd48494ef417162864f7924b6fee1e8621d0 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Sat, 15 Aug 2020 13:12:05 +0200 Subject: [PATCH 536/547] Working SQL Server CRUD, running model1.py test works. Basic JSON has also been implemented. 
--- pony/orm/dbproviders/mssql.py | 350 ++++++++++++++++++++++++++++++++++ 1 file changed, 350 insertions(+) create mode 100644 pony/orm/dbproviders/mssql.py diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py new file mode 100644 index 000000000..ebb5e9cc1 --- /dev/null +++ b/pony/orm/dbproviders/mssql.py @@ -0,0 +1,350 @@ +from __future__ import absolute_import +from pony.py23compat import PY2, imap, basestring, buffer, int_types + +import json +from decimal import Decimal +from datetime import datetime, date, time, timedelta +from uuid import UUID + +NoneType = type(None) + +import warnings +warnings.filterwarnings('ignore', '^Table.+already exists$', Warning, '^pony\\.orm\\.dbapiprovider$') + +try: + import pyodbc as mssql_module + MSSQL_module_name = 'pyodbc' +except ImportError: + raise ImportError('In order to use PonyORM with MSSQL please install pyodbc') + +from pony.orm import core, dbschema, dbapiprovider, ormtypes, sqltranslation +from pony.orm.core import log_orm +from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions +from pony.orm.sqltranslation import SQLTranslator, TranslationError +from pony.orm.sqlbuilding import Value, Param, SQLBuilder, join +from pony.utils import throw +from pony.converting import str2timedelta, timedelta2str + +class MSSQLColumn(dbschema.Column): + auto_template = '%(type)s IDENTITY(1,1) PRIMARY KEY' + +class MSSQLSchema(dbschema.DBSchema): + dialect = 'MSSQL' + inline_fk_syntax = False + column_class = MSSQLColumn + +class MSSQLTranslator(SQLTranslator): + dialect = 'MSSQL' + json_path_wildcard_syntax = True + +class MSSQLValue(Value): + __slots__ = [] + def __unicode__(self): + value = self.value + if isinstance(value, timedelta): + if value.microseconds: + return "INTERVAL '%s' HOUR_MICROSECOND" % timedelta2str(value) + return "INTERVAL '%s' HOUR_SECOND" % timedelta2str(value) + return Value.__unicode__(self) + if not PY2: __str__ = __unicode__ + +class 
MSSQLBuilder(SQLBuilder): + dialect = 'MSSQL' + value_class = MSSQLValue + + def CONCAT(builder, *args): + return 'CONCAT(', join(', ', imap(builder, args)), ')' + def TRIM(builder, expr, chars=None): + if chars is None: return 'TRIM(', builder(expr), ')' + return 'TRIM(', builder(chars), ' FROM ' ,builder(expr), ')' + def LTRIM(builder, expr, chars=None): + if chars is None: return 'ltrim(', builder(expr), ')' + return 'LTRIM(', builder(chars), ' FROM ' ,builder(expr), ')' + def RTRIM(builder, expr, chars=None): + if chars is None: return 'RTRIM(', builder(expr), ')' + return 'RTRIM(', builder(chars), ' FROM ' ,builder(expr), ')' + def TO_INT(builder, expr): + return 'CAST(', builder(expr), ' AS int)' + def TO_REAL(builder, expr): + return 'CAST(', builder(expr), ' AS float)' + def TO_STR(builder, expr): + return 'CAST(', builder(expr), ' AS nvarchar)' + def YEAR(builder, expr): + return 'YEAR(', builder(expr), ')' + def MONTH(builder, expr): + return 'MONTH(', builder(expr), ')' + def DAY(builder, expr): + return 'DAY(', builder(expr), ')' + def HOUR(builder, expr): + return 'DATEPART(hh, ', builder(expr), ')' + def MINUTE(builder, expr): + return 'DATEPART(n, ', builder(expr), ')' + def SECOND(builder, expr): + return 'DATEPART(ss, ', builder(expr), ')' + +# TODO: not fixed this yet + def DATE_ADD(builder, expr, delta): + if delta[0] == 'VALUE' and isinstance(delta[1], time): + return 'ADDTIME(', builder(expr), ', ', builder(delta), ')' + return 'ADDDATE(', builder(expr), ', ', builder(delta), ')' + def DATE_SUB(builder, expr, delta): + if delta[0] == 'VALUE' and isinstance(delta[1], time): + return 'SUBTIME(', builder(expr), ', ', builder(delta), ')' + return 'SUBDATE(', builder(expr), ', ', builder(delta), ')' + def DATE_DIFF(builder, expr1, expr2): + return 'TIMEDIFF(', builder(expr1), ', ', builder(expr2), ')' + def DATETIME_ADD(builder, expr, delta): + return builder.DATE_ADD(expr, delta) + def DATETIME_SUB(builder, expr, delta): + return 
builder.DATE_SUB(expr, delta) + def DATETIME_DIFF(builder, expr1, expr2): + return 'TIMEDIFF(', builder(expr1), ', ', builder(expr2), ')' +# End todo + + def JSON_VALUE(builder, expr, path, type): + path_sql, has_params, has_wildcards = builder.build_json_path(path) + escaped = escapify(builder(expr)) + result = 'JSON_VALUE(', escaped, ', ', path_sql, ')' + return result + +class MSSQLStrConverter(dbapiprovider.StrConverter): + def sql_type(converter): + if converter.max_len: + return 'NVARCHAR(%d)' % converter.max_len + return 'TEXT' + +class MSSQLRealConverter(dbapiprovider.RealConverter): + def sql_type(converter): + return 'float' + +class MSSQLBlobConverter(dbapiprovider.BlobConverter): + def sql_type(converter): + return 'LONGBLOB' + +class MSSQLTimeConverter(dbapiprovider.TimeConverter): + def sql2py(converter, val): + if isinstance(val, timedelta): # MSSQLdb returns timedeltas instead of times + total_seconds = val.days * (24 * 60 * 60) + val.seconds + if 0 <= total_seconds <= 24 * 60 * 60: + minutes, seconds = divmod(total_seconds, 60) + hours, minutes = divmod(minutes, 60) + return time(hours, minutes, seconds, val.microseconds) + elif not isinstance(val, time): throw(ValueError, + 'Value of unexpected type received from database%s: instead of time or timedelta got %s' + % ('for attribute %s' % converter.attr if converter.attr else '', type(val))) + return val + +class MSSQLTimedeltaConverter(dbapiprovider.TimedeltaConverter): + sql_type_name = 'TIME' + +class MSSQLUuidConverter(dbapiprovider.UuidConverter): + def sql_type(converter): + return 'BINARY(16)' + +class MSSQLJsonConverter(dbapiprovider.JsonConverter): + def sql_type(self): + return 'NVARCHAR(MAX)' + +def quotify(sql, arguments): + for arg in arguments: + if isinstance(arg, str): + sql = sql.replace('{}', "'" + arg + "'" , 1) + else: + sql = sql.replace('{}', str(arg), 1) + + return sql + +def escapify(sql): + for escape in range(int(sql.count('|') / 2)): + sql = sql.replace('|', '[', 1) + sql 
= sql.replace('|', ']', 1) + return sql + +class MSSQLProvider(DBAPIProvider): + dialect = 'MSSQL' + paramstyle = 'format' + quote_char = "|" + max_name_len = 64 + max_params_count = 10000 + table_if_not_exists_syntax = True + index_if_not_exists_syntax = False + max_time_precision = default_time_precision = 0 + varchar_default_max_len = 255 + uint64_support = True + + dbapi_module = mssql_module + dbschema_cls = MSSQLSchema + translator_cls = MSSQLTranslator + sqlbuilder_cls = MSSQLBuilder + + default_schema_name = 'dbo' + + fk_types = { 'SERIAL' : 'BIGINT UNSIGNED' } + + converter_classes = [ + (NoneType, dbapiprovider.NoneConverter), + (bool, dbapiprovider.BoolConverter), + (basestring, MSSQLStrConverter), + (int_types, dbapiprovider.IntConverter), + (float, MSSQLRealConverter), + (Decimal, dbapiprovider.DecimalConverter), + (datetime, dbapiprovider.DatetimeConverter), + (date, dbapiprovider.DateConverter), + (time, MSSQLTimeConverter), + (timedelta, MSSQLTimedeltaConverter), + (UUID, MSSQLUuidConverter), + (buffer, MSSQLBlobConverter), + (ormtypes.Json, MSSQLJsonConverter), + ] + + def normalize_name(provider, name): + return name[:provider.max_name_len].lower() + + @wrap_dbapi_exceptions + def inspect_connection(provider, connection): + cursor = connection.cursor() + cursor.execute('Select @@version') + row = cursor.fetchone() + assert row is not None + provider.server_version = row[0] + # cursor.execute('select database()') + # provider.default_schema_name = cursor.fetchone()[0] + # cursor.execute('set session group_concat_max_len = 4294967295') + + def should_reconnect(provider, exc): + return isinstance(exc, mssql_module.OperationalError) and exc.args[0] in (2006, 2013) + + def get_pool(provider, *args, **kwargs): + driver = kwargs['driver'] + server = kwargs['server'] + database = kwargs['database'] + username = kwargs['username'] + password = kwargs['password'] + connection_string = 
f'Driver={driver};Server={server};Database={database};UID={username};PWD={password};MARS_Connection=Yes' + return Pool(mssql_module, connection_string, **kwargs) + + @wrap_dbapi_exceptions + def set_transaction_mode(provider, connection, cache): + assert not cache.in_transaction + db_session = cache.db_session + if db_session is not None and db_session.ddl: + cache.in_transaction = True + if db_session is not None and db_session.serializable: + cursor = connection.cursor() + sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' + if core.local.debug: log_orm(sql) + cursor.execute(sql) + cache.in_transaction = True + + @wrap_dbapi_exceptions + def execute(provider, cursor, sql, arguments=None, returning_id=False): + sql = sql.replace('`', "") + sql = sql.replace('\n', ' ') + sql = sql.replace('%s', '?') + sql = escapify(sql) + if type(arguments) is list: + assert arguments and not returning_id + cursor.executemany(sql, arguments) + else: + if arguments is None: cursor.execute(sql) + else: + cursor.execute(sql, arguments) + if returning_id: return int(cursor.execute('SELECT SCOPE_IDENTITY() AS [SCOPE_IDENTITY]').fetchone()[0]) + + + @wrap_dbapi_exceptions + def release(provider, connection, cache=None): + if cache is not None: + db_session = cache.db_session + if db_session is not None and db_session.ddl and cache.saved_fk_state: + try: + cursor = connection.cursor() + sql = 'SET foreign_key_checks = 1' + if core.local.debug: log_orm(sql) + cursor.execute(sql) + except: + provider.pool.drop(connection) + raise + DBAPIProvider.release(provider, connection, cache) + + def table_exists(provider, connection, table_name, case_sensitive=True): + db_name, table_name = provider.split_table_name(table_name) + cursor = connection.cursor() + sql = """ + SELECT table_name FROM information_schema.tables + WHERE table_schema=? and table_name=? 
+ """ + cursor.execute(sql, [ db_name, table_name ]) + row = cursor.fetchone() + return row[0] if row is not None else None + + def index_exists(provider, connection, table_name, index_name, case_sensitive=True): + #TODO: might not be ready, need to check this + db_name, table_name = provider.split_table_name(table_name) + sql = """ + SELECT name + FROM sys.indexes + WHERE object_id = OBJECT_ID(?) + AND name=? + """ + cursor = connection.cursor() + cursor.execute(sql, [f'{db_name}.{table_name}', index_name]) + row = cursor.fetchone() + return row[0] if row is not None else None + + def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): + db_name, table_name = provider.split_table_name(table_name) + + sql = """ + SELECT name + FROM sys.foreign_keys + WHERE object_id = OBJECT_ID(?) + """ + cursor = connection.cursor() + cursor.execute(sql, [ f'{db_name}.{fk_name}']) + row = cursor.fetchone() + return row[0] if row is not None else None + + def table_has_data(provider, connection, table_name): + cursor = connection.cursor() + provider.execute(cursor, escapify('SELECT TOP 1 * FROM %s' % provider.quote_name(table_name))) + return cursor.fetchone() is not None + + def drop_table(self, connection, table_name): + + cursor = connection.cursor() + + sql = ''' + + DECLARE @sql nvarchar(1000) + + WHILE EXISTS( + SELECT * + FROM sys.foreign_keys + WHERE referenced_object_id = object_id('%(table)s') + ) + BEGIN + SELECT + @sql = 'ALTER TABLE ' + OBJECT_SCHEMA_NAME(parent_object_id) + + '.[' + OBJECT_NAME(parent_object_id) + + '] DROP CONSTRAINT ' + name + FROM sys.foreign_keys + WHERE referenced_object_id = object_id('%(table)s') + exec sp_executesql @sql + END + + DROP TABLE "%(table)s" + + ''' % { 'table' : table_name } + + cursor.execute(sql) + +provider_cls = MSSQLProvider + +def str2datetime(s): + if 19 < len(s) < 26: s += '000000'[:26-len(s)] + s = s.replace('-', ' ').replace(':', ' ').replace('.', ' ').replace('T', ' ') + try: + return 
datetime(*imap(int, s.split())) + except ValueError: + return None # for incorrect values like 0000-00-00 00:00:00 From 92802400e616ca7d512183f0307967286772740f Mon Sep 17 00:00:00 2001 From: SOSTHOLM Date: Wed, 19 Aug 2020 10:54:02 +0200 Subject: [PATCH 537/547] added bool fix --- pony/orm/dbproviders/mssql.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index ebb5e9cc1..cc7331adb 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -117,6 +117,10 @@ class MSSQLRealConverter(dbapiprovider.RealConverter): def sql_type(converter): return 'float' +class MSSQLBoolConverter(dbapiprovider.BoolConverter): + def sql_type(converter): + return 'BIT' + class MSSQLBlobConverter(dbapiprovider.BlobConverter): def sql_type(converter): return 'LONGBLOB' @@ -183,7 +187,7 @@ class MSSQLProvider(DBAPIProvider): converter_classes = [ (NoneType, dbapiprovider.NoneConverter), - (bool, dbapiprovider.BoolConverter), + (bool, MSSQLBoolConverter), (basestring, MSSQLStrConverter), (int_types, dbapiprovider.IntConverter), (float, MSSQLRealConverter), From 6159c9a843849ba43b58befb208cf140d5c9e369 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Sat, 29 Aug 2020 08:15:46 +0200 Subject: [PATCH 538/547] WIP: mssql basics are working, including basic JSON fields, DATE_ADD etc, not yet tested --- pony/orm/dbproviders/mssqlserver.py | 263 ---------------------------- 1 file changed, 263 deletions(-) delete mode 100644 pony/orm/dbproviders/mssqlserver.py diff --git a/pony/orm/dbproviders/mssqlserver.py b/pony/orm/dbproviders/mssqlserver.py deleted file mode 100644 index 58c060264..000000000 --- a/pony/orm/dbproviders/mssqlserver.py +++ /dev/null @@ -1,263 +0,0 @@ -from __future__ import absolute_import -from pony.py23compat import PY2, imap, basestring, buffer, int_types - -from decimal import Decimal -from datetime import datetime, date, time, timedelta -from uuid import UUID - 
-NoneType = type(None) - -import warnings -warnings.filterwarnings('ignore', '^Table.+already exists$', Warning, '^pony\\.orm\\.dbapiprovider$') - -import pyodbc - -from pony.orm import core, dbschema, dbapiprovider -from pony.orm.core import log_orm, OperationalError -from pony.orm.dbapiprovider import DBAPIProvider, Pool, get_version_tuple, wrap_dbapi_exceptions -from pony.orm.sqltranslation import SQLTranslator -from pony.orm.sqlbuilding import SQLBuilder, join, indentable, make_unary_func -from pony.utils import throw -from pony.converting import str2timedelta, timedelta2str - -class MSColumn(dbschema.Column): - auto_template = '%(type)s PRIMARY KEY IDENTITY(1, 1)' - -class MSSchema(dbschema.DBSchema): - dialect = 'MSSQL' - inline_fk_syntax = False - column_class = MSColumn - -class MSTranslator(SQLTranslator): - dialect = 'MSSQL' - -class MSBuilder(SQLBuilder): - dialect = 'MSSQL' - - def INSERT(builder, table_name, columns, values, returning=None): - ret = [ - 'INSERT INTO ', builder.quote_name(table_name), ' (', - join(', ', [builder.quote_name(column) for column in columns ]), - ')' - ] - if returning is not None: - ret.extend(( - ' OUTPUT inserted.', builder.quote_name(returning), - )) - ret.extend(( - ' VALUES (', - join(', ', [builder(value) for value in values]), ')' - )) - return ret - - LENGTH = make_unary_func('LEN') - - @indentable - def LIMIT(builder, limit, offset=None): - if offset is None: - offset = ('VALUE', 0) - return 'OFFSET ', builder(offset), ' ROWS FETCH NEXT ', builder(limit), ' ROWS ONLY\n' - - -class MSBoolConverter(dbapiprovider.BoolConverter): - def sql_type(converter): - return "BIT" - - -class MSIntConverter(dbapiprovider.IntConverter): - signed_types = {None: 'INTEGER', 8: 'SMALLINT', 16: 'SMALLINT', 24: 'INTEGER', 32: 'INTEGER', 64: 'BIGINT'} - unsigned_types = {None: 'INTEGER', 8: 'TINYINT', 16: 'INTEGER', 24: 'INTEGER', 32: 'BIGINT'} - -class MSStrConverter(dbapiprovider.StrConverter): - def sql_type(converter): - if 
converter.max_len: - return 'VARCHAR(%d)' % converter.max_len - attr = converter.attr - if attr is not None and (attr.is_unique or attr.composite_keys): - return 'VARCHAR(8000)' - return 'VARCHAR(MAX)' - -class MSRealConverter(dbapiprovider.RealConverter): - def sql_type(converter): - return 'FLOAT' - -class MSBlobConverter(dbapiprovider.BlobConverter): - def sql_type(converter): - return 'VARBINARY(MAX)' - -class MSTimeConverter(dbapiprovider.TimeConverter): - def sql2py(converter, val): - if isinstance(val, timedelta): # MySQLdb returns timedeltas instead of times - total_seconds = val.days * (24 * 60 * 60) + val.seconds - if 0 <= total_seconds <= 24 * 60 * 60: - minutes, seconds = divmod(total_seconds, 60) - hours, minutes = divmod(minutes, 60) - return time(hours, minutes, seconds, val.microseconds) - elif not isinstance(val, time): throw(ValueError, - 'Value of unexpected type received from database%s: instead of time or timedelta got %s' - % ('for attribute %s' % converter.attr if converter.attr else '', type(val))) - return val - -class MSTimedeltaConverter(dbapiprovider.TimedeltaConverter): - sql_type_name = 'TIME' - -class MSUuidConverter(dbapiprovider.UuidConverter): - def sql_type(converter): - return 'BINARY(16)' - -class MSDateConverter(dbapiprovider.DateConverter): - def py2sql(converter, val): - val = dbapiprovider.DateConverter.py2sql(converter, val) - if isinstance(val, date): - val = val.strftime('%Y-%m-%d') - return val - def sql2py(converter, val): - if isinstance(val, basestring): - val = datetime.strptime(val, "%Y-%m-%d").date() - val = dbapiprovider.DateConverter.sql2py(converter, val) - return val - def sql_type(converter): - return 'DATE' - -class MSProvider(DBAPIProvider): - dialect = 'MSSQL' - paramstyle = 'qmark' - quote_char = '"' - max_name_len = 128 - - table_if_not_exists_syntax = True - index_if_not_exists_syntax = False - select_for_update_nowait_syntax = False - max_time_precision = default_time_precision = 0 - 
varchar_default_max_len = 255 - uint64_support = True - - dbapi_module = pyodbc - - dbschema_cls = MSSchema - translator_cls = MSTranslator - sqlbuilder_cls = MSBuilder - - - default_schema_name = 'dbo' - name_before_table = 'db_name' - - converter_classes = [ - (NoneType, dbapiprovider.NoneConverter), - (bool, MSBoolConverter), - (basestring, MSStrConverter), - (int_types, MSIntConverter), - (float, MSRealConverter), - (Decimal, dbapiprovider.DecimalConverter), - (datetime, dbapiprovider.DatetimeConverter), - (date, MSDateConverter), - (time, MSTimeConverter), - (timedelta, MSTimedeltaConverter), - (UUID, MSUuidConverter), - (buffer, MSBlobConverter), - ] - - def normalize_name(provider, name): - return name[:provider.max_name_len].lower() - - @wrap_dbapi_exceptions - def inspect_connection(provider, connection): - cursor = connection.cursor() - cursor.execute('select @@version') - row = cursor.fetchone() - assert row is not None - cursor.execute('select DB_NAME()') - provider.default_schema_name = cursor.fetchone()[0] - - def should_reconnect(provider, exc): - return isinstance(exc, pyodbc.OperationalError) - - @wrap_dbapi_exceptions - def set_transaction_mode(provider, connection, cache): - assert not cache.in_transaction - db_session = cache.db_session - if db_session is not None and db_session.serializable: - cursor = connection.cursor() - sql = 'SET TRANSACTION ISOLATION LEVEL SERIALIZABLE' - if core.debug: log_orm(sql) - cursor.execute(sql) - cache.immediate = True - if db_session is not None and (db_session.serializable or db_session.ddl): - cache.in_transaction = True - - @wrap_dbapi_exceptions - def execute(provider, cursor, sql, arguments=None, returning_id=False): - if type(arguments) is list: - assert arguments and not returning_id - cursor.executemany(sql, arguments) - else: - if arguments is None: cursor.execute(sql) - else: cursor.execute(sql, arguments) - if returning_id: - return cursor.fetchone()[0] - - - def table_exists(provider, connection, 
table_name, case_sensitive=True): - db_name, table_name = provider.split_table_name(table_name) - cursor = connection.cursor() - if case_sensitive: sql = 'SELECT table_name FROM information_schema.tables ' \ - 'WHERE table_schema=? and table_name=?' - else: sql = 'SELECT table_name FROM information_schema.tables ' \ - 'WHERE table_schema=? and UPPER(table_name)=UPPER(?)' - cursor.execute(sql, [ db_name, table_name ]) - row = cursor.fetchone() - return row[0] if row is not None else None - - def index_exists(provider, connection, table_name, index_name, case_sensitive=True): - table_name = provider.quote_name(table_name) - if case_sensitive: sql = "SELECT top 1 1 FROM sys.indexes WHERE name=? AND object_id=OBJECT_ID(?)" - else: sql = "SELECT top 1 1 FROM sys.indexes WHERE lower(name)=lower(?) AND object_id=OBJECT_ID(?)" - cursor = connection.cursor() - cursor.execute(sql, [ index_name, table_name ]) - row = cursor.fetchone() - return row[0] if row is not None else None - - def fk_exists(provider, connection, table_name, fk_name, case_sensitive=True): - schema_name, table_name = provider.split_table_name(table_name) - table_name = provider.quote_name(table_name) - fk_name = provider.quote_name([ schema_name, fk_name ]) - # if case_sensitive: ??? - sql = "SELECT name FROM sys.foreign_keys " \ - "WHERE object_id = OBJECT_ID(?) 
AND parent_object_id=OBJECT_ID(?)"; - cursor = connection.cursor() - cursor.execute(sql, [ fk_name, table_name ]) - row = cursor.fetchone() - return row[0] if row is not None else None - - def table_has_data(provider, connection, table_name): - table_name = provider.quote_name(table_name) - cursor = connection.cursor() - cursor.execute('SELECT TOP 1 1 FROM %s' % table_name) - return cursor.fetchone() is not None - - def drop_table(provider, connection, table_name): - table_name = provider.quote_name(table_name) - cursor = connection.cursor() - sql = ''' - DECLARE @sql nvarchar(1000) - - WHILE EXISTS( - SELECT * - FROM sys.foreign_keys - WHERE referenced_object_id = object_id('%(table)s') - ) - BEGIN - SELECT - @sql = 'ALTER TABLE ' + OBJECT_SCHEMA_NAME(parent_object_id) + - '.[' + OBJECT_NAME(parent_object_id) + - '] DROP CONSTRAINT ' + name - FROM sys.foreign_keys - WHERE referenced_object_id = object_id('%(table)s') - exec sp_executesql @sql - END - - DROP TABLE %(table)s - ''' % { 'table' : table_name } - cursor.execute(sql) - -provider_cls = MSProvider From 1cd4a1cef43f5637088f516f17593c2f456c109c Mon Sep 17 00:00:00 2001 From: SOSTHOLM Date: Wed, 7 Oct 2020 09:39:01 +0200 Subject: [PATCH 539/547] added support for limit --- pony/orm/dbproviders/mssql.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index cc7331adb..ecc0f8716 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -54,33 +54,51 @@ class MSSQLBuilder(SQLBuilder): def CONCAT(builder, *args): return 'CONCAT(', join(', ', imap(builder, args)), ')' + def TRIM(builder, expr, chars=None): if chars is None: return 'TRIM(', builder(expr), ')' return 'TRIM(', builder(chars), ' FROM ' ,builder(expr), ')' + def LTRIM(builder, expr, chars=None): if chars is None: return 'ltrim(', builder(expr), ')' return 'LTRIM(', builder(chars), ' FROM ' ,builder(expr), ')' + def RTRIM(builder, expr, 
chars=None): if chars is None: return 'RTRIM(', builder(expr), ')' return 'RTRIM(', builder(chars), ' FROM ' ,builder(expr), ')' + def TO_INT(builder, expr): return 'CAST(', builder(expr), ' AS int)' + def TO_REAL(builder, expr): return 'CAST(', builder(expr), ' AS float)' + def TO_STR(builder, expr): return 'CAST(', builder(expr), ' AS nvarchar)' + def YEAR(builder, expr): return 'YEAR(', builder(expr), ')' + def MONTH(builder, expr): return 'MONTH(', builder(expr), ')' + def DAY(builder, expr): return 'DAY(', builder(expr), ')' + def HOUR(builder, expr): return 'DATEPART(hh, ', builder(expr), ')' + def MINUTE(builder, expr): return 'DATEPART(n, ', builder(expr), ')' + def SECOND(builder, expr): return 'DATEPART(ss, ', builder(expr), ')' + + def LIMIT(builder, limit, offset=0): + if [True for ast in builder.ast if 'ORDER_BY' in ast[0]][0] == False: + return 'ORDER BY (SELECT NULL)' + f'OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY' + + return f'OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY' # TODO: not fixed this yet def DATE_ADD(builder, expr, delta): From 6989956a418479c46206689006dbdc76ce33d183 Mon Sep 17 00:00:00 2001 From: SOSTHOLM Date: Fri, 16 Oct 2020 13:48:29 +0200 Subject: [PATCH 540/547] Update mssql.py Fixed py datetime to sql datetime string conversion --- pony/orm/dbproviders/mssql.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index ecc0f8716..66439458b 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -45,6 +45,11 @@ def __unicode__(self): if value.microseconds: return "INTERVAL '%s' HOUR_MICROSECOND" % timedelta2str(value) return "INTERVAL '%s' HOUR_SECOND" % timedelta2str(value) + + if isinstance(value, datetime): + result = value.isoformat(' ') + return self.quote_str(result) + return Value.__unicode__(self) if not PY2: __str__ = __unicode__ From af1b8b20f6d5cca1f067c26e015114a975d3c4f8 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm 
Date: Mon, 19 Apr 2021 11:06:30 +0200 Subject: [PATCH 541/547] fixing limit mssql builder for the case where order_by is null. Also adding the possibility of returning None id after insert --- pony/orm/dbproviders/mssql.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index 66439458b..e985ec89d 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -100,7 +100,8 @@ def SECOND(builder, expr): return 'DATEPART(ss, ', builder(expr), ')' def LIMIT(builder, limit, offset=0): - if [True for ast in builder.ast if 'ORDER_BY' in ast[0]][0] == False: + order_by = [True for ast in builder.ast if 'ORDER_BY' in ast[0]] + if not order_by or order_by[0] == False: return 'ORDER BY (SELECT NULL)' + f'OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY' return f'OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY' @@ -276,7 +277,12 @@ def execute(provider, cursor, sql, arguments=None, returning_id=False): if arguments is None: cursor.execute(sql) else: cursor.execute(sql, arguments) - if returning_id: return int(cursor.execute('SELECT SCOPE_IDENTITY() AS [SCOPE_IDENTITY]').fetchone()[0]) + if returning_id: + id = cursor.execute('SELECT SCOPE_IDENTITY() AS [SCOPE_IDENTITY]').fetchone()[0] + if id: + return int(id) + else: + return id @wrap_dbapi_exceptions From d6aeaf8bf559902cca7f142781dcae049eb09106 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Wed, 28 Apr 2021 09:37:04 +0200 Subject: [PATCH 542/547] fixing issue with where wildcard string searches for s gets replaced to ? 
--- .vscode/settings.json | 3 +++ models.py | 21 +++++++++++++++++++++ pony/orm/dbproviders/mssql.py | 7 +++++-- 3 files changed, 29 insertions(+), 2 deletions(-) create mode 100644 .vscode/settings.json create mode 100644 models.py diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 000000000..a8cb05707 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "python.pythonPath": "C:\\Users\\sostholm\\Envs\\pony\\Scripts\\python.exe" +} \ No newline at end of file diff --git a/models.py b/models.py new file mode 100644 index 000000000..abdb3be5a --- /dev/null +++ b/models.py @@ -0,0 +1,21 @@ +from datetime import datetime +from uuid import UUID, uuid4 +from pony.orm import * + +db = Database() + +class User(db.Entity): + id = PrimaryKey(UUID, default=uuid4) + created = Required(datetime, default=lambda: datetime.now()) + + + + +db.bind('mssql', driver='ODBC Driver 17 for SQL Server', server='10.24.219.31', database='testing', username="tester2", password="tester123!") +db.generate_mapping(create_tables=True) + +if __name__ == '__main__': + with db_session: + usr = User() + users = select(u for u in User)[:] + print(users) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index e985ec89d..c15c692ee 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -5,7 +5,7 @@ from decimal import Decimal from datetime import datetime, date, time, timedelta from uuid import UUID - +import re NoneType = type(None) import warnings @@ -25,6 +25,8 @@ from pony.utils import throw from pony.converting import str2timedelta, timedelta2str +PYODBC_VAR_REGEX = re.compile(r'(? Date: Wed, 28 Apr 2021 09:45:42 +0200 Subject: [PATCH 543/547] fixing mistake where sql gets replaced with ? 
--- pony/orm/dbproviders/mssql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index c15c692ee..bcf630bb7 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -271,7 +271,7 @@ def execute(provider, cursor, sql, arguments=None, returning_id=False): sql = sql.replace('`', "") sql = sql.replace('\n', ' ') sql = sql.replace('%%', '') - sql = PYODBC_VAR_REGEX.sub('?', '%s') + sql = PYODBC_VAR_REGEX.sub('?', sql) sql = escapify(sql) if type(arguments) is list: assert arguments and not returning_id From 4c6629bfe45d2bfa2438578afb70447dd61d3073 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Wed, 28 Apr 2021 09:50:09 +0200 Subject: [PATCH 544/547] removing wildcard replace in execute --- pony/orm/dbproviders/mssql.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index bcf630bb7..182dc09a4 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -270,7 +270,6 @@ def set_transaction_mode(provider, connection, cache): def execute(provider, cursor, sql, arguments=None, returning_id=False): sql = sql.replace('`', "") sql = sql.replace('\n', ' ') - sql = sql.replace('%%', '') sql = PYODBC_VAR_REGEX.sub('?', sql) sql = escapify(sql) if type(arguments) is list: From 217b64695a51a1c03ebd46e6b65c6fbc2565825c Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Fri, 28 May 2021 12:55:13 +0200 Subject: [PATCH 545/547] Should fix issue where ids are not returned after inserts --- pony/orm/dbproviders/mssql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index 182dc09a4..7b8ca65e6 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -280,7 +280,8 @@ def execute(provider, cursor, sql, arguments=None, returning_id=False): else: cursor.execute(sql, arguments) if returning_id: 
- id = cursor.execute('SELECT SCOPE_IDENTITY() AS [SCOPE_IDENTITY]').fetchone()[0] + id = cursor.execute('SELECT @@Identity').fetchone()[0] + # id = cursor.execute('SELECT SCOPE_IDENTITY() AS [SCOPE_IDENTITY]').fetchone()[0] if id: return int(id) else: From 274f8baadef35a8d3a7832d04d40e54a5fd9b260 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Fri, 30 Jul 2021 10:14:38 +0200 Subject: [PATCH 546/547] fixing returning id after insert --- pony/orm/dbproviders/mssql.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index 7b8ca65e6..195452a31 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -58,6 +58,13 @@ def __unicode__(self): class MSSQLBuilder(SQLBuilder): dialect = 'MSSQL' value_class = MSSQLValue + + def INSERT(builder, table_name, columns, values, returning=None): + result = SQLBuilder.INSERT(builder, table_name, columns, values) + if returning is not None: + values_index = result.index(') VALUES (') + result[values_index] = ') OUTPUT Inserted.ID VALUES (' + return result def CONCAT(builder, *args): return 'CONCAT(', join(', ', imap(builder, args)), ')' @@ -277,15 +284,17 @@ def execute(provider, cursor, sql, arguments=None, returning_id=False): cursor.executemany(sql, arguments) else: if arguments is None: cursor.execute(sql) - else: - cursor.execute(sql, arguments) - if returning_id: - id = cursor.execute('SELECT @@Identity').fetchone()[0] - # id = cursor.execute('SELECT SCOPE_IDENTITY() AS [SCOPE_IDENTITY]').fetchone()[0] - if id: - return int(id) else: - return id + if returning_id: + sql = sql + _id = cursor.execute(sql, arguments).fetchone()[0] + + if not _id: + _id = cursor.execute('SELECT @@Identity').fetchone()[0] + + return _id + else: + cursor.execute(sql, arguments) @wrap_dbapi_exceptions From 11a9e5c6991689e24c6d3246a8737e8b570950e5 Mon Sep 17 00:00:00 2001 From: Samuel Ostholm Date: Sat, 2 Oct 2021 
21:42:24 +0200 Subject: [PATCH 547/547] adding ssl support for connecting to mssql database --- pony/orm/dbproviders/mssql.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pony/orm/dbproviders/mssql.py b/pony/orm/dbproviders/mssql.py index 195452a31..37c923658 100644 --- a/pony/orm/dbproviders/mssql.py +++ b/pony/orm/dbproviders/mssql.py @@ -254,10 +254,17 @@ def should_reconnect(provider, exc): def get_pool(provider, *args, **kwargs): driver = kwargs['driver'] server = kwargs['server'] - database = kwargs['database'] username = kwargs['username'] password = kwargs['password'] - connection_string = f'Driver={driver};Server={server};Database={database};UID={username};PWD={password};MARS_Connection=Yes' + connection_string = f'Driver={driver};Server={server};UID={username};PWD={password};MARS_Connection=Yes' + + if 'encrypt' in kwargs: + connection_string += f';Encrypt={kwargs["encrypt"]}' + if 'trust_server_certificate' in kwargs: + connection_string += f';TrustServerCertificate={kwargs["trust_server_certificate"]}' + if 'database' in kwargs: + connection_string += f';Database={kwargs["database"]}' + return Pool(mssql_module, connection_string, **kwargs) @wrap_dbapi_exceptions