Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent a Session from being closed multiple times #660

Open
wants to merge 4 commits into
base: orm-migrations
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pony/orm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ def rollback():
class DBSessionContextManager(object):
__slots__ = 'retry', 'retry_exceptions', 'allowed_exceptions', \
'immediate', 'ddl', 'serializable', 'strict', 'optimistic', \
'sql_debug', 'show_values'
'sql_debug', 'show_values', '_closed'
def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False, strict=False, optimistic=True,
retry_exceptions=(TransactionError,), allowed_exceptions=(), sql_debug=None, show_values=None):
if retry != 0:
Expand All @@ -448,6 +448,7 @@ def __init__(db_session, retry=0, immediate=False, ddl=False, serializable=False
db_session.allowed_exceptions = allowed_exceptions
db_session.sql_debug = sql_debug
db_session.show_values = show_values
db_session._closed = None
def __call__(db_session, *args, **kwargs):
if not args and not kwargs: return db_session
if len(args) > 1: throw(TypeError,
Expand All @@ -464,6 +465,7 @@ def __enter__(db_session):
"@db_session can accept 'retry' parameter only when used as decorator and not as context manager")
db_session._enter()
def _enter(db_session):

if local.db_session is None:
assert not local.db_context_counter
local.db_session = db_session
Expand All @@ -472,10 +474,14 @@ def _enter(db_session):
elif db_session.serializable and not local.db_session.serializable: throw(TransactionError,
'Cannot start serializable transaction inside non-serializable transaction')
local.db_context_counter += 1
db_session._closed = False
if db_session.sql_debug is not None:
local.push_debug_state(db_session.sql_debug, db_session.show_values)
def __exit__(db_session, exc_type=None, exc=None, tb=None):
if db_session._closed: return
local.db_context_counter -= 1
assert local.db_context_counter >= 0
db_session._closed = True
try:
if not local.db_context_counter:
assert local.db_session is db_session
Expand Down
24 changes: 24 additions & 0 deletions pony/orm/decompiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,9 @@ def COMPARE_OP(decompiler, op):
oper1 = decompiler.stack.pop()
return ast.Compare(oper1, [(op, oper2)])

def CONTAINS_OP(decompiler, invert):
return decompiler.COMPARE_OP('not in' if invert else 'in')

def DUP_TOP(decompiler):
return decompiler.stack[-1]

Expand All @@ -353,6 +356,9 @@ def JUMP_IF_FALSE(decompiler, endpos):

JUMP_IF_FALSE_OR_POP = JUMP_IF_FALSE

def JUMP_IF_NOT_EXC_MATCH(decompiler, endpos):
raise NotImplementedError

def JUMP_IF_TRUE(decompiler, endpos):
return decompiler.conditional_jump(endpos, True)

Expand Down Expand Up @@ -436,10 +442,28 @@ def JUMP_FORWARD(decompiler, endpos):
if decompiler.targets.get(endpos) is then: decompiler.targets[endpos] = if_exp
return if_exp

def IS_OP(decompiler, invert):
return decompiler.COMPARE_OP('is not' if invert else 'is')

def LIST_APPEND(decompiler, offset=None):
throw(InvalidQuery('Use generator expression (... for ... in ...) '
'instead of list comprehension [... for ... in ...] inside query'))

def LIST_EXTEND(decompiler, offset):
if offset != 1:
raise NotImplementedError(offset)
items = decompiler.stack.pop()
if not isinstance(items, ast.Const):
raise NotImplementedError(type(items))
if not isinstance(items.value, tuple):
raise NotImplementedError(type(items.value))
lst = decompiler.stack.pop()
if not isinstance(lst, ast.List):
raise NotImplementedError(type(lst))
values = tuple(ast.Const(v) for v in items.value)
lst.nodes = lst.nodes + values
return lst

def LOAD_ATTR(decompiler, attr_name):
return ast.Getattr(decompiler.stack.pop(), attr_name)

Expand Down
9 changes: 6 additions & 3 deletions pony/orm/tests/test_declarative_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,12 @@ def test29(self):
@raises_exception(NotImplementedError, "date(s.id, 1, 1)")
def test30(self):
select(s for s in Student if s.dob < date(s.id, 1, 1))
@raises_exception(ExprEvalError, "`max()` raises TypeError: max() expects at least one argument" if PYPY else
"`max()` raises TypeError: max expected 1 arguments, got 0" if sys.version_info[:2] < (3, 8) else
"`max()` raises TypeError: max expected 1 argument, got 0")
@raises_exception(
ExprEvalError,
"`max()` raises TypeError: max() expects at least one argument" if PYPY else
"`max()` raises TypeError: max expected 1 arguments, got 0" if sys.version_info[:2] < (3, 8) else
"`max()` raises TypeError: max expected 1 argument, got 0" if sys.version_info[:2] < (3, 9) else
"`max()` raises TypeError: max expected at least 1 argument, got 0")
def test31(self):
select(s for s in Student if s.id < max())
@raises_exception(TypeError, "Incomparable types 'Student' and 'Course' in expression: s in s.courses")
Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,11 +107,10 @@ def test_suite():

if __name__ == "__main__":
pv = sys.version_info[:2]
if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8)):
if pv not in ((2, 7), (3, 3), (3, 4), (3, 5), (3, 6), (3, 7), (3, 8), (3, 9)):
s = "Sorry, but %s %s requires Python of one of the following versions: 2.7, 3.3-3.8." \
" You have version %s"
print(s % (name, version, sys.version.split(' ', 1)[0]))
sys.exit(1)

REQUIRES = ['docopt']

Expand Down