Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix slcanBus.get_version() and slcanBus.get_serial_number() #1904

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions can/interfaces/slcan.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import logging
import time
import warnings
from typing import Any, Optional, Tuple, Union
from queue import SimpleQueue
from typing import Any, Optional, Tuple, Union, cast

from can import BitTiming, BitTimingFd, BusABC, CanProtocol, Message, typechecking
from can.exceptions import (
Expand Down Expand Up @@ -131,6 +132,7 @@ def __init__(
timeout=timeout,
)

self._queue: SimpleQueue[str] = SimpleQueue()
self._buffer = bytearray()
self._can_protocol = CanProtocol.CAN_20

Expand Down Expand Up @@ -196,7 +198,7 @@ def _read(self, timeout: Optional[float]) -> Optional[str]:
# We read the `serialPortOrig.in_waiting` only once here.
in_waiting = self.serialPortOrig.in_waiting
for _ in range(max(1, in_waiting)):
new_byte = self.serialPortOrig.read(size=1)
new_byte = self.serialPortOrig.read(1)
if new_byte:
self._buffer.extend(new_byte)
else:
Expand Down Expand Up @@ -234,7 +236,10 @@ def _recv_internal(
extended = False
data = None

string = self._read(timeout)
if self._queue.qsize():
string: Optional[str] = self._queue.get_nowait()
else:
string = self._read(timeout)

if not string:
pass
Expand Down Expand Up @@ -300,7 +305,7 @@ def shutdown(self) -> None:

def fileno(self) -> int:
try:
return self.serialPortOrig.fileno()
return cast(int, self.serialPortOrig.fileno())
except io.UnsupportedOperation:
raise NotImplementedError(
"fileno is not implemented using current CAN bus on this platform"
Expand All @@ -321,19 +326,21 @@ def get_version(
int hw_version is the hardware version or None on timeout
int sw_version is the software version or None on timeout
"""
_timeout = serial.Timeout(timeout)
cmd = "V"
self._write(cmd)

string = self._read(timeout)

if not string:
pass
elif string[0] == cmd and len(string) == 6:
# convert ASCII coded version
hw_version = int(string[1:3])
sw_version = int(string[3:5])
return hw_version, sw_version

while True:
if string := self._read(_timeout.time_left()):
if string[0] == cmd:
# convert ASCII coded version
hw_version = int(string[1:3])
sw_version = int(string[3:5])
return hw_version, sw_version
else:
self._queue.put_nowait(string)
if _timeout.expired():
break
return None, None

def get_serial_number(self, timeout: Optional[float]) -> Optional[str]:
Expand All @@ -345,15 +352,17 @@ def get_serial_number(self, timeout: Optional[float]) -> Optional[str]:
:return:
:obj:`None` on timeout or a :class:`str` object.
"""
_timeout = serial.Timeout(timeout)
cmd = "N"
self._write(cmd)

string = self._read(timeout)

if not string:
pass
elif string[0] == cmd and len(string) == 6:
serial_number = string[1:-1]
return serial_number

while True:
if string := self._read(_timeout.time_left()):
if string[0] == cmd:
serial_number = string[1:-1]
return serial_number
else:
self._queue.put_nowait(string)
if _timeout.expired():
break
return None
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ exclude = [
"^can/interfaces/neousys",
"^can/interfaces/pcan",
"^can/interfaces/serial",
"^can/interfaces/slcan",
"^can/interfaces/socketcan",
"^can/interfaces/systec",
"^can/interfaces/udp_multicast",
Expand Down
100 changes: 79 additions & 21 deletions test/test_slcan.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python

import unittest
from typing import cast
import unittest.mock
from typing import cast, Optional

import serial
from serial.serialutil import SerialBase

import can.interfaces.slcan

Expand All @@ -21,20 +21,69 @@
TIMEOUT = 0.5 if IS_PYPY else 0.01 # 0.001 is the default set in slcanBus


class SerialMock(SerialBase):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

self._input_buffer = b""
self._output_buffer = b""

def open(self) -> None:
self.is_open = True

def close(self) -> None:
self.is_open = False
self._input_buffer = b""
self._output_buffer = b""

def read(self, size: int = -1, /) -> bytes:
if size > 0:
data = self._input_buffer[:size]
self._input_buffer = self._input_buffer[size:]
return data
return b""

def write(self, b: bytes, /) -> Optional[int]:
self._output_buffer = b
if b == b"N\r":
self.set_input_buffer(b"NA123\r")
elif b == b"V\r":
self.set_input_buffer(b"V1013\r")
return len(b)

def set_input_buffer(self, expected: bytes) -> None:
self._input_buffer = expected

def get_output_buffer(self) -> bytes:
return self._output_buffer

def reset_input_buffer(self) -> None:
self._input_buffer = b""

@property
def in_waiting(self) -> int:
return len(self._input_buffer)

@classmethod
def serial_for_url(cls, *args, **kwargs) -> SerialBase:
return cls(*args, **kwargs)


class slcanTestCase(unittest.TestCase):
@unittest.mock.patch("serial.serial_for_url", SerialMock.serial_for_url)
def setUp(self):
self.bus = cast(
can.interfaces.slcan.slcanBus,
can.Bus("loop://", interface="slcan", sleep_after_open=0, timeout=TIMEOUT),
)
self.serial = cast(serial.Serial, self.bus.serialPortOrig)
self.serial = cast(SerialMock, self.bus.serialPortOrig)
self.serial.reset_input_buffer()

def tearDown(self):
self.bus.shutdown()

def test_recv_extended(self):
self.serial.write(b"T12ABCDEF2AA55\r")
self.serial.set_input_buffer(b"T12ABCDEF2AA55\r")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, 0x12ABCDEF)
Expand All @@ -44,7 +93,7 @@ def test_recv_extended(self):
self.assertSequenceEqual(msg.data, [0xAA, 0x55])

# Ewert Energy Systems CANDapter specific
self.serial.write(b"x12ABCDEF2AA55\r")
self.serial.set_input_buffer(b"x12ABCDEF2AA55\r")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, 0x12ABCDEF)
Expand All @@ -54,15 +103,19 @@ def test_recv_extended(self):
self.assertSequenceEqual(msg.data, [0xAA, 0x55])

def test_send_extended(self):
payload = b"T12ABCDEF2AA55\r"
msg = can.Message(
arbitration_id=0x12ABCDEF, is_extended_id=True, data=[0xAA, 0x55]
)
self.bus.send(msg)
self.assertEqual(payload, self.serial.get_output_buffer())

self.serial.set_input_buffer(payload)
rx_msg = self.bus.recv(TIMEOUT)
self.assertTrue(msg.equals(rx_msg, timestamp_delta=None))

def test_recv_standard(self):
self.serial.write(b"t4563112233\r")
self.serial.set_input_buffer(b"t4563112233\r")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, 0x456)
Expand All @@ -72,15 +125,19 @@ def test_recv_standard(self):
self.assertSequenceEqual(msg.data, [0x11, 0x22, 0x33])

def test_send_standard(self):
payload = b"t4563112233\r"
msg = can.Message(
arbitration_id=0x456, is_extended_id=False, data=[0x11, 0x22, 0x33]
)
self.bus.send(msg)
self.assertEqual(payload, self.serial.get_output_buffer())

self.serial.set_input_buffer(payload)
rx_msg = self.bus.recv(TIMEOUT)
self.assertTrue(msg.equals(rx_msg, timestamp_delta=None))

def test_recv_standard_remote(self):
self.serial.write(b"r1238\r")
self.serial.set_input_buffer(b"r1238\r")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, 0x123)
Expand All @@ -89,15 +146,19 @@ def test_recv_standard_remote(self):
self.assertEqual(msg.dlc, 8)

def test_send_standard_remote(self):
payload = b"r1238\r"
msg = can.Message(
arbitration_id=0x123, is_extended_id=False, is_remote_frame=True, dlc=8
)
self.bus.send(msg)
self.assertEqual(payload, self.serial.get_output_buffer())

self.serial.set_input_buffer(payload)
rx_msg = self.bus.recv(TIMEOUT)
self.assertTrue(msg.equals(rx_msg, timestamp_delta=None))

def test_recv_extended_remote(self):
self.serial.write(b"R12ABCDEF6\r")
self.serial.set_input_buffer(b"R12ABCDEF6\r")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, 0x12ABCDEF)
Expand All @@ -106,19 +167,23 @@ def test_recv_extended_remote(self):
self.assertEqual(msg.dlc, 6)

def test_send_extended_remote(self):
payload = b"R12ABCDEF6\r"
msg = can.Message(
arbitration_id=0x12ABCDEF, is_extended_id=True, is_remote_frame=True, dlc=6
)
self.bus.send(msg)
self.assertEqual(payload, self.serial.get_output_buffer())

self.serial.set_input_buffer(payload)
rx_msg = self.bus.recv(TIMEOUT)
self.assertTrue(msg.equals(rx_msg, timestamp_delta=None))

def test_partial_recv(self):
self.serial.write(b"T12ABCDEF")
self.serial.set_input_buffer(b"T12ABCDEF")
msg = self.bus.recv(TIMEOUT)
self.assertIsNone(msg)

self.serial.write(b"2AA55\rT12")
self.serial.set_input_buffer(b"2AA55\rT12")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)
self.assertEqual(msg.arbitration_id, 0x12ABCDEF)
Expand All @@ -130,28 +195,21 @@ def test_partial_recv(self):
msg = self.bus.recv(TIMEOUT)
self.assertIsNone(msg)

self.serial.write(b"ABCDEF2AA55\r")
self.serial.set_input_buffer(b"ABCDEF2AA55\r")
msg = self.bus.recv(TIMEOUT)
self.assertIsNotNone(msg)

def test_version(self):
self.serial.write(b"V1013\r")
hw_ver, sw_ver = self.bus.get_version(0)
self.assertEqual(b"V\r", self.serial.get_output_buffer())
self.assertEqual(hw_ver, 10)
self.assertEqual(sw_ver, 13)

hw_ver, sw_ver = self.bus.get_version(0)
self.assertIsNone(hw_ver)
self.assertIsNone(sw_ver)

def test_serial_number(self):
self.serial.write(b"NA123\r")
sn = self.bus.get_serial_number(0)
self.assertEqual(b"N\r", self.serial.get_output_buffer())
self.assertEqual(sn, "A123")

sn = self.bus.get_serial_number(0)
self.assertIsNone(sn)


if __name__ == "__main__":
unittest.main()
Loading