Skip to content

Commit

Permalink
Introduce client object in python tests (#772)
Browse files Browse the repository at this point in the history
Thus far, the client end of the socket is the only piece of client state
tracked in tests, for which a global `socket` variable has been used. In
preparation to add more state, replace the `socket` global with a
`client` global object that groups all client state.

Signed-off-by: Mattias Nissler <[email protected]>
Reviewed-by: John Levon <[email protected]>
Reviewed-by: Thanos Makatos <[email protected]>
  • Loading branch information
mnissler-rivos authored Aug 31, 2023
1 parent 2e8ec2e commit a7eedff
Show file tree
Hide file tree
Showing 21 changed files with 365 additions and 351 deletions.
51 changes: 32 additions & 19 deletions test/py/libvfio_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,25 +688,38 @@ def connect_sock():
return sock


def connect_client(ctx):
sock = connect_sock()

json = b'{ "capabilities": { "max_msg_fds": 8 } }'
# struct vfio_user_version
payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR,
LIBVFIO_USER_MINOR, json, b'\0')
hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload))
sock.send(hdr + payload)
vfu_attach_ctx(ctx, expect=0)
payload = get_reply(sock, expect=0)
return sock


def disconnect_client(ctx, sock):
sock.close()

# notice client closed connection
vfu_run_ctx(ctx, errno.ENOTCONN)
class Client:
"""Models a VFIO-user client connected to the server under test."""

def __init__(self, sock=None):
self.sock = sock
self.client_cmd_socket = None

def connect(self, ctx):
self.sock = connect_sock()

json = b'{ "capabilities": { "max_msg_fds": 8 } }'
# struct vfio_user_version
payload = struct.pack("HH%dsc" % len(json), LIBVFIO_USER_MAJOR,
LIBVFIO_USER_MINOR, json, b'\0')
hdr = vfio_user_header(VFIO_USER_VERSION, size=len(payload))
self.sock.send(hdr + payload)
vfu_attach_ctx(ctx, expect=0)
payload = get_reply(self.sock, expect=0)
return self.sock

def disconnect(self, ctx):
self.sock.close()
self.sock = None

# notice client closed connection
vfu_run_ctx(ctx, errno.ENOTCONN)


def connect_client(*args, **kwargs):
client = Client()
client.connect(*args, **kwargs)
return client


def get_reply(sock, expect=0):
Expand Down
4 changes: 2 additions & 2 deletions test/py/test_destroy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@


def setup_function(function):
global ctx, sock
global ctx, client
ctx = prepare_ctx_for_dma()
assert ctx is not None
sock = connect_client(ctx)
client = connect_client(ctx)


def teardown_function(function):
Expand Down
10 changes: 5 additions & 5 deletions test/py/test_device_get_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,27 +49,27 @@ def test_device_get_info():

# test short write

sock = connect_client(ctx)
client = connect_client(ctx)

payload = struct.pack("II", 0, 0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_INFO, payload,
expect=errno.EINVAL)

# bad argsz

payload = vfio_user_device_info(argsz=8, flags=0,
num_regions=0, num_irqs=0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_INFO, payload,
expect=errno.EINVAL)

# valid with larger argsz

payload = vfio_user_device_info(argsz=32, flags=0,
num_regions=0, num_irqs=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_INFO, payload)

(argsz, flags, num_regions, num_irqs) = struct.unpack("IIII", result)

Expand All @@ -78,7 +78,7 @@ def test_device_get_info():
assert num_regions == VFU_PCI_DEV_NUM_REGIONS
assert num_irqs == VFU_DEV_NUM_IRQS

disconnect_client(ctx, sock)
client.disconnect(ctx)

vfu_destroy_ctx(ctx)

Expand Down
22 changes: 11 additions & 11 deletions test/py/test_device_get_irq_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@
import errno

ctx = None
sock = None
client = None

argsz = len(vfio_irq_info())


def test_device_get_irq_info_setup():
global ctx, sock
global ctx, client

ctx = vfu_create_ctx(flags=LIBVFIO_USER_FLAG_ATTACH_NB)
assert ctx is not None
Expand All @@ -55,27 +55,27 @@ def test_device_get_irq_info_setup():
ret = vfu_realize_ctx(ctx)
assert ret == 0

sock = connect_client(ctx)
client = connect_client(ctx)


def test_device_get_irq_info_bad_in():
payload = struct.pack("II", 0, 0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload,
expect=errno.EINVAL)

# bad argsz
payload = vfio_irq_info(argsz=8, flags=0, index=VFU_DEV_REQ_IRQ,
count=0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload,
expect=errno.EINVAL)

# bad index
payload = vfio_irq_info(argsz=argsz, flags=0, index=VFU_DEV_NUM_IRQS,
count=0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload,
expect=errno.EINVAL)


Expand All @@ -86,12 +86,12 @@ def test_device_get_irq_info():
payload = vfio_irq_info(argsz=argsz + 16, flags=0, index=VFU_DEV_REQ_IRQ,
count=0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)

payload = vfio_irq_info(argsz=argsz, flags=0, index=VFU_DEV_REQ_IRQ,
count=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)

info, _ = vfio_irq_info.pop_from_buffer(result)

Expand All @@ -103,7 +103,7 @@ def test_device_get_irq_info():
payload = vfio_irq_info(argsz=argsz, flags=0, index=VFU_DEV_ERR_IRQ,
count=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)

info, _ = vfio_irq_info.pop_from_buffer(result)

Expand All @@ -115,7 +115,7 @@ def test_device_get_irq_info():
payload = vfio_irq_info(argsz=argsz, flags=0, index=VFU_DEV_MSIX_IRQ,
count=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_IRQ_INFO, payload)

info, _ = vfio_irq_info.pop_from_buffer(result)

Expand All @@ -126,7 +126,7 @@ def test_device_get_irq_info():


def test_device_get_irq_info_cleanup():
disconnect_client(ctx, sock)
client.disconnect(ctx)

vfu_destroy_ctx(ctx)

Expand Down
49 changes: 25 additions & 24 deletions test/py/test_device_get_region_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@
import tempfile

ctx = None
sock = None
client = None

argsz = len(vfio_region_info())
migr_region_size = 2 << PAGE_SHIFT
migr_mmap_areas = [(PAGE_SIZE, PAGE_SIZE)]


def test_device_get_region_info_setup():
global ctx, sock
global ctx, client

ctx = vfu_create_ctx(flags=LIBVFIO_USER_FLAG_ATTACH_NB)
assert ctx is not None
Expand Down Expand Up @@ -89,14 +89,14 @@ def test_device_get_region_info_setup():
ret = vfu_realize_ctx(ctx)
assert ret == 0

sock = connect_client(ctx)
client = connect_client(ctx)


def test_device_get_region_info_short_write():

payload = struct.pack("II", 0, 0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload,
expect=errno.EINVAL)


Expand All @@ -106,7 +106,7 @@ def test_device_get_region_info_bad_argsz():
index=VFU_PCI_DEV_BAR1_REGION_IDX, cap_offset=0,
size=0, offset=0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload,
expect=errno.EINVAL)


Expand All @@ -116,7 +116,7 @@ def test_device_get_region_info_bad_index():
index=VFU_PCI_DEV_NUM_REGIONS, cap_offset=0,
size=0, offset=0)

msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload,
msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload,
expect=errno.EINVAL)


Expand All @@ -126,7 +126,7 @@ def test_device_get_region_info_larger_argsz():
index=VFU_PCI_DEV_BAR1_REGION_IDX, cap_offset=0,
size=0, offset=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)

assert len(result) == argsz

Expand All @@ -142,13 +142,13 @@ def test_device_get_region_info_larger_argsz():


def test_device_get_region_info_small_argsz_caps():
global sock
global client

payload = vfio_region_info(argsz=argsz, flags=0,
index=VFU_PCI_DEV_BAR2_REGION_IDX, cap_offset=0,
size=0, offset=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)

info, _ = vfio_region_info.pop_from_buffer(result)

Expand All @@ -167,20 +167,21 @@ def test_device_get_region_info_small_argsz_caps():
assert info.offset == 0x8000

# skip reading the SCM_RIGHTS
disconnect_client(ctx, sock)
client.disconnect(ctx)


def test_device_get_region_info_caps():
global sock
global client

sock = connect_client(ctx)
client = connect_client(ctx)

payload = vfio_region_info(argsz=80, flags=0,
index=VFU_PCI_DEV_BAR2_REGION_IDX, cap_offset=0,
size=0, offset=0)
payload = bytes(payload) + b'\0' * (80 - 32)

fds, result = msg_fds(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)
fds, result = msg_fds(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO,
payload)

info, result = vfio_region_info.pop_from_buffer(result)
cap, result = vfio_region_info_cap_sparse_mmap.pop_from_buffer(result)
Expand All @@ -203,20 +204,20 @@ def test_device_get_region_info_caps():
assert area2.size == 0x2000

assert len(fds) == 1
disconnect_client(ctx, sock)
client.disconnect(ctx)


def test_device_get_region_info_migr():
global sock
global client

sock = connect_client(ctx)
client = connect_client(ctx)

payload = vfio_region_info(argsz=80, flags=0,
index=VFU_PCI_DEV_MIGR_REGION_IDX, cap_offset=0,
size=0, offset=0)
payload = bytes(payload) + b'\0' * (80 - 32)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)

info, result = vfio_region_info.pop_from_buffer(result)
mcap, result = vfio_region_info_cap_type.pop_from_buffer(result)
Expand All @@ -241,7 +242,7 @@ def test_device_get_region_info_migr():
assert area.size == migr_mmap_areas[0][1]

# skip reading the SCM_RIGHTS
disconnect_client(ctx, sock)
client.disconnect(ctx)


def test_device_get_region_info_cleanup():
Expand All @@ -260,13 +261,13 @@ def test_device_get_pci_config_space_info_implicit_pci_init():
ret = vfu_realize_ctx(ctx)
assert ret == 0

sock = connect_client(ctx)
client = connect_client(ctx)

payload = vfio_region_info(argsz=argsz + 8, flags=0,
index=VFU_PCI_DEV_CFG_REGION_IDX, cap_offset=0,
size=0, offset=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)

assert len(result) == argsz

Expand All @@ -281,7 +282,7 @@ def test_device_get_pci_config_space_info_implicit_pci_init():
assert info.size == PCI_CFG_SPACE_EXP_SIZE
assert info.offset == 0

disconnect_client(ctx, sock)
client.disconnect(ctx)

vfu_destroy_ctx(ctx)

Expand All @@ -296,13 +297,13 @@ def test_device_get_pci_config_space_info_implicit_no_pci_init():
ret = vfu_realize_ctx(ctx)
assert ret == 0

sock = connect_client(ctx)
client = connect_client(ctx)

payload = vfio_region_info(argsz=argsz + 8, flags=0,
index=VFU_PCI_DEV_CFG_REGION_IDX, cap_offset=0,
size=0, offset=0)

result = msg(ctx, sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)
result = msg(ctx, client.sock, VFIO_USER_DEVICE_GET_REGION_INFO, payload)

assert len(result) == argsz

Expand All @@ -315,7 +316,7 @@ def test_device_get_pci_config_space_info_implicit_no_pci_init():
assert info.size == PCI_CFG_SPACE_SIZE
assert info.offset == 0

disconnect_client(ctx, sock)
client.disconnect(ctx)

vfu_destroy_ctx(ctx)

Expand Down
Loading

0 comments on commit a7eedff

Please sign in to comment.