diff --git a/src/omero/gateway/__init__.py b/src/omero/gateway/__init__.py index ae23448d7..42f2ce660 100644 --- a/src/omero/gateway/__init__.py +++ b/src/omero/gateway/__init__.py @@ -26,7 +26,7 @@ from past.utils import old_div from builtins import object import os - +from functools import wraps import warnings from collections import defaultdict @@ -1505,6 +1505,23 @@ def values(self): return () +def assertConnected(func): + """decorator that raises an exception if the decorated function is used when the gateway + is not connected. + + Assumes the first argument is a BlitzGateway instance, or the function + is a BlitzGateway method + + """ + @wraps(func) + def wrapped(conn, *args, **kwargs): + if not conn._connected: + raise Ice.ConnectionLostException("You need to be connected to execute this function") + + return func(conn, *args, **kwargs) + return wrapped + + class _BlitzGateway (object): """ Connection wrapper. Handles connecting and keeping the session alive, @@ -2308,12 +2325,11 @@ def isConnected(self): :return: Boolean """ - return self._connected ######################## # # Connection Stuff # # - + @assertConnected def getEventContext(self): """ Returns omero_System_ice.EventContext. diff --git a/test/unit/test_gateway.py b/test/unit/test_gateway.py index 4d3d0363e..5af2929c9 100644 --- a/test/unit/test_gateway.py +++ b/test/unit/test_gateway.py @@ -30,7 +30,7 @@ import sys from omero.gateway import BlitzGateway, ImageWrapper, \ - WellWrapper, LogicalChannelWrapper, OriginalFileWrapper + WellWrapper, LogicalChannelWrapper, OriginalFileWrapper, assertConnected from omero.model import ImageI, PixelsI, ExperimenterI, EventI, \ ProjectI, TagAnnotationI, FileAnnotationI, OriginalFileI, \ MapAnnotationI, NamedValue, PlateI, WellI, \ @@ -40,6 +40,31 @@ from omero.rtypes import rstring, rtime, rlong, rint, rdouble +def test_assertConnected(): + + class MockGateway: + _connected = False + + def connect(self): + self._connected = True + return True + + def isConnected(self): + return self._connected + + @assertConnected + def foo(self): + return 1 + + conn = MockGateway() + assert not conn.isConnected() + with pytest.raises(Ice.ConnectionLostException): + conn.foo() + + assert conn.connect() + assert conn.foo() + + class MockQueryService(object): def __init__(self, obj_to_be_returned):