diff --git a/asyncua/sync.py b/asyncua/sync.py index c92468fe0..09ba154e8 100644 --- a/asyncua/sync.py +++ b/asyncua/sync.py @@ -2,6 +2,7 @@ sync API of asyncua """ import asyncio +import functools from pathlib import Path from threading import Thread, Condition import logging @@ -98,30 +99,75 @@ def wrapper(self, *args, **kwargs): return wrapper +def sync_wrapper(aio_func): + def wrapper(*args, **kwargs): + if not args: + raise RuntimeError("first argument of function must a ThreadLoop object") + if isinstance(args[0], ThreadLoop): + tloop = args[0] + args = list(args)[1:] + elif hasattr(args[0], "tloop"): + tloop = args[0].tloop + else: + raise RuntimeError("first argument of function must a ThreadLoop object") + args, kwargs = _to_async(args, kwargs) + result = tloop.post(aio_func(*args, **kwargs)) + return _to_sync(tloop, result) + + return wrapper + + def syncfunc(aio_func): """ decorator for sync function """ def decorator(func, *args, **kwargs): - def wrapper(*args, **kwargs): - if not args: - raise RuntimeError("first argument of function must a ThreadLoop object") - if isinstance(args[0], ThreadLoop): - tloop = args[0] - args = list(args)[1:] - elif hasattr(args[0], "tloop"): - tloop = args[0].tloop - else: - raise RuntimeError("first argument of function must a ThreadLoop object") - args, kwargs = _to_async(args, kwargs) - result = tloop.post(aio_func(*args, **kwargs)) - return _to_sync(tloop, result) - - return wrapper + return sync_wrapper(aio_func) return decorator +def sync_uaclient_method(aio_func): + """ + Usage: + + ```python + from asyncua.client.ua_client import UaClient + from asyncua.sync import Client + + with Client('otp.tcp://localhost') as client: + read_attributes = sync_uaclient_method(UaClient.read_attributes)(client) + results = read_attributes(...) + ... + ``` + """ + def sync_method(client: 'Client'): + uaclient = client.aio_obj.uaclient + return functools.partial(sync_wrapper(aio_func), client.tloop, uaclient) + + return sync_method + + +def sync_async_client_method(aio_func): + """ + Usage: + + ```python + from asyncua.client import Client as AsyncClient + from asyncua.sync import Client + + with Client('otp.tcp://localhost') as client: + read_attributes = sync_async_client_method(AsyncClient.read_attributes)(client) + results = read_attributes(...) + ... + ``` + """ + def sync_method(client: 'Client'): + return functools.partial(sync_wrapper(aio_func), client.tloop, client) + + return sync_method + + @syncfunc(aio_func=common.methods.call_method_full) def call_method_full(parent, methodid, *args): pass diff --git a/tests/test_sync.py b/tests/test_sync.py index 06492c165..b8090ab02 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -4,7 +4,21 @@ import pytest -from asyncua.sync import Client, Server, ThreadLoop, SyncNode, call_method_full, XmlExporter, new_enum, new_struct, new_struct_field +from asyncua.client import Client as AsyncClient +from asyncua.client.ua_client import UaClient +from asyncua.sync import ( + Client, + Server, + ThreadLoop, + SyncNode, + call_method_full, + XmlExporter, + new_enum, + new_struct, + new_struct_field, + sync_async_client_method, + sync_uaclient_method, +) from asyncua import ua, uamethod @@ -61,6 +75,24 @@ def test_sync_client(client, idx): assert myvar.read_value() == 6.7 +def test_sync_uaclient_method(client, idx): + client.load_type_definitions() + myvar = client.nodes.root.get_child(["0:Objects", f"{idx}:MyObject", f"{idx}:MyVariable"]) + read_attributes = sync_uaclient_method(UaClient.read_attributes)(client) + results = read_attributes([myvar.nodeid], attr=ua.AttributeIds.Value) + assert len(results) == 1 + assert results[0].Value.Value == 6.7 + + +def test_sync_async_client_method(client, idx): + client.load_type_definitions() + myvar = client.nodes.root.get_child(["0:Objects", f"{idx}:MyObject", f"{idx}:MyVariable"]) + read_attributes = sync_async_client_method(AsyncClient.read_attributes)(client) + results = read_attributes([myvar], attr=ua.AttributeIds.Value) + assert len(results) == 1 + assert results[0].Value.Value == 6.7 + + def test_sync_client_get_node(client): node = client.get_node(85) assert node == client.nodes.objects