Skip to content

Commit

Permalink
add sync_wrapper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
okapies committed Jul 21, 2023
1 parent e1dc3be commit 5b59411
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 16 deletions.
76 changes: 61 additions & 15 deletions asyncua/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
sync API of asyncua
"""
import asyncio
import functools
from pathlib import Path
from threading import Thread, Condition
import logging
Expand Down Expand Up @@ -98,30 +99,75 @@ def wrapper(self, *args, **kwargs):
return wrapper


def sync_wrapper(aio_func):
def wrapper(*args, **kwargs):
if not args:
raise RuntimeError("first argument of function must a ThreadLoop object")
if isinstance(args[0], ThreadLoop):
tloop = args[0]
args = list(args)[1:]
elif hasattr(args[0], "tloop"):
tloop = args[0].tloop
else:
raise RuntimeError("first argument of function must a ThreadLoop object")
args, kwargs = _to_async(args, kwargs)
result = tloop.post(aio_func(*args, **kwargs))
return _to_sync(tloop, result)

return wrapper


def syncfunc(aio_func):
"""
decorator for sync function
"""
def decorator(func, *args, **kwargs):
def wrapper(*args, **kwargs):
if not args:
raise RuntimeError("first argument of function must a ThreadLoop object")
if isinstance(args[0], ThreadLoop):
tloop = args[0]
args = list(args)[1:]
elif hasattr(args[0], "tloop"):
tloop = args[0].tloop
else:
raise RuntimeError("first argument of function must a ThreadLoop object")
args, kwargs = _to_async(args, kwargs)
result = tloop.post(aio_func(*args, **kwargs))
return _to_sync(tloop, result)

return wrapper
return sync_wrapper(aio_func)

return decorator


def sync_uaclient_method(aio_func):
"""
Usage:
```python
from asyncua.client.ua_client import UaClient
from asyncua.sync import Client
with Client('otp.tcp://localhost') as client:
read_attributes = sync_uaclient_method(UaClient.read_attributes)(client)
results = read_attributes(...)
...
```
"""
def sync_method(client: 'Client'):
uaclient = client.aio_obj.uaclient
return functools.partial(sync_wrapper(aio_func), client.tloop, uaclient)

return sync_method


def sync_async_client_method(aio_func):
"""
Usage:
```python
from asyncua.client import Client as AsyncClient
from asyncua.sync import Client
with Client('otp.tcp://localhost') as client:
read_attributes = sync_async_client_method(AsyncClient.read_attributes)(client)
results = read_attributes(...)
...
```
"""
def sync_method(client: 'Client'):
return functools.partial(sync_wrapper(aio_func), client.tloop, client)

return sync_method


@syncfunc(aio_func=common.methods.call_method_full)
def call_method_full(parent, methodid, *args):
pass
Expand Down
34 changes: 33 additions & 1 deletion tests/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,21 @@

import pytest

from asyncua.sync import Client, Server, ThreadLoop, SyncNode, call_method_full, XmlExporter, new_enum, new_struct, new_struct_field
from asyncua.client import Client as AsyncClient
from asyncua.client.ua_client import UaClient
from asyncua.sync import (
Client,
Server,
ThreadLoop,
SyncNode,
call_method_full,
XmlExporter,
new_enum,
new_struct,
new_struct_field,
sync_async_client_method,
sync_uaclient_method,
)
from asyncua import ua, uamethod


Expand Down Expand Up @@ -61,6 +75,24 @@ def test_sync_client(client, idx):
assert myvar.read_value() == 6.7


def test_sync_uaclient_method(client, idx):
client.load_type_definitions()
myvar = client.nodes.root.get_child(["0:Objects", f"{idx}:MyObject", f"{idx}:MyVariable"])
read_attributes = sync_uaclient_method(UaClient.read_attributes)(client)
results = read_attributes([myvar.nodeid], attr=ua.AttributeIds.Value)
assert len(results) == 1
assert results[0].Value.Value == 6.7


def test_sync_async_client_method(client, idx):
client.load_type_definitions()
myvar = client.nodes.root.get_child(["0:Objects", f"{idx}:MyObject", f"{idx}:MyVariable"])
read_attributes = sync_async_client_method(AsyncClient.read_attributes)(client)
results = read_attributes([myvar], attr=ua.AttributeIds.Value)
assert len(results) == 1
assert results[0].Value.Value == 6.7


def test_sync_client_get_node(client):
node = client.get_node(85)
assert node == client.nodes.objects
Expand Down

0 comments on commit 5b59411

Please sign in to comment.