From 70540055985b45d7638c79ebe6db150294eef0f7 Mon Sep 17 00:00:00 2001 From: Thijs Schreijer Date: Wed, 3 Nov 2021 16:27:23 +0100 Subject: [PATCH] feat(topics) add validation and matching --- .editorconfig | 14 ++ mqtt/init.lua | 177 +++++++++++++++++++++++++ tests/spec/topics.lua | 300 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 491 insertions(+) create mode 100644 .editorconfig create mode 100644 tests/spec/topics.lua diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..d646b37 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,14 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +charset = utf-8 + +[*.lua] +indent_style = tab +indent_size = 4 + +[Makefile] +indent_style = tab diff --git a/mqtt/init.lua b/mqtt/init.lua index 5f9e848..3c042d4 100644 --- a/mqtt/init.lua +++ b/mqtt/init.lua @@ -81,6 +81,183 @@ function mqtt.run_sync(cl) end end + +--- Validates a topic with wildcards. +-- @param t (string) wildcard topic to validate +-- @return topic, or false+error +function mqtt.validate_subscribe_topic(t) + if type(t) ~= "string" then + return false, "not a string" + end + if #t < 1 then + return false, "minimum topic length is 1" + end + do + local _, count = t:gsub("#", "") + if count > 1 then + return false, "wildcard '#' may only appear once" + end + if count == 1 then + if t ~= "#" and not t:find("/#$") then + return false, "wildcard '#' must be the last character, and be prefixed with '/' (unless the topic is '#')" + end + end + end + do + local t1 = "/"..t.."/" + local i = 1 + while i do + i = t1:find("+", i) + if i then + if t1:sub(i-1, i+1) ~= "/+/" then + return false, "wildcard '+' must be enclosed between '/' (except at start/end)" + end + i = i + 1 + end + end + end + return t +end + +--- Validates a topic without wildcards. +-- @param t (string) topic to validate +-- @return topic, or false+error +function mqtt.validate_publish_topic(t) + if type(t) ~= "string" then + return false, "not a string" + end + if #t < 1 then + return false, "minimum topic length is 1" + end + if t:find("+", nil, true) or t:find("#", nil, true) then + return false, "wildcards '#', and '+' are not allowed when publishing" + end + return t +end + +--- Returns a Lua pattern from topic. +-- Takes a wildcarded-topic and returns a Lua pattern that can be used +-- to validate if a received topic matches the wildcard-topic +-- @param t (string) the wildcard topic +-- @return Lua-pattern (string) or false+err +-- @usage +-- local patt = compile_topic_pattern("homes/+/+/#") +-- +-- local topic = "homes/myhome/living/mainlights/brightness" +-- local homeid, roomid, varargs = topic:match(patt) +function mqtt.compile_topic_pattern(t) + local ok, err = mqtt.validate_subscribe_topic(t) + if not ok then + return ok, err + end + if t == "#" then + t = "(.+)" -- matches anything at least 1 character long + else + t = t:gsub("#","(.-)") -- match anything, can be empty + t = t:gsub("%+","([^/]-)") -- match anything between '/', can be empty + end + return "^"..t.."$" +end + +--- Parses wildcards in a topic into a table. +-- Options include: +-- +-- - `opts.topic`: the wild-carded topic to match against (optional if `opts.pattern` is given) +-- +-- - `opts.pattern`: the compiled pattern for the wild-carded topic (optional if `opts.topic` +-- is given). If not given then topic will be compiled and the result will be +-- stored in this field for future use (cache). +-- +-- - `opts.keys`: (optional) array of field names. The order must be the same as the +-- order of the wildcards in `topic` +-- +-- Returned tables: +-- +-- - `fields` table: the array part will have the values of the wildcards, in +-- the order they appeared. The hash part, will have the field names provided +-- in `opts.keys`, with the values of the corresponding wildcard. If a `#` +-- wildcard was used, that one will be the last in the table. +-- +-- - `varargs` table: will only be returned if the wildcard topic contained the +-- `#` wildcard. The returned table is an array, with all segments that were +-- matched by the `#` wildcard. +-- @param topic (string) incoming topic string (required) +-- @param opts (table) with options (required) +-- @return fields (table) + varargs (table or nil), or false+err on error. +-- @usage +-- local opts = { +-- topic = "homes/+/+/#", +-- keys = { "homeid", "roomid", "varargs"}, +-- } +-- local fields, varargs = topic_match("homes/myhome/living/mainlights/brightness", opts) +-- +-- print(fields[1], fields.homeid) -- "myhome myhome" +-- print(fields[2], fields.roomid) -- "living living" +-- print(fields[3], fields.varargs) -- "mainlights/brightness mainlights/brightness" +-- +-- print(varargs[1]) -- "mainlights" +-- print(varargs[2]) -- "brightness" +function mqtt.topic_match(topic, opts) + if type(topic) ~= "string" then + return false, "expected topic to be a string" + end + if type(opts) ~= "table" then + return false, "expected optionss to be a table" + end + local pattern = opts.pattern + if not pattern then + local ptopic = opts.topic + if not ptopic then + return false, "either 'opts.topic' or 'opts.pattern' must set" + end + local err + pattern, err = mqtt.compile_topic_pattern(ptopic) + if not pattern then + return false, "failed to compile 'opts.topic' into pattern: "..tostring(err) + end + -- store/cache compiled pattern for next time + opts.pattern = pattern + end + local values = { topic:match(pattern) } + if values[1] == nil then + return false, "topic does not match wildcard pattern" + end + local keys = opts.keys + if keys ~= nil then + if type(keys) ~= "table" then + return false, "expected 'opts.keys' to be a table (array)" + end + -- we have a table with keys, copy values to fields + for i, value in ipairs(values) do + local key = keys[i] + if key ~= nil then + values[key] = value + end + end + end + if not pattern:find("%(%.[%-%+]%)%$$") then -- pattern for "#" as last char + -- we're done + return values + end + -- we have a '#' wildcard + local vararg = values[#values] + local varargs = {} + local i = 0 + local ni = 0 + while ni do + ni = vararg:find("/", i, true) + if ni then + varargs[#varargs + 1] = vararg:sub(i, ni-1) + i = ni + 1 + else + varargs[#varargs + 1] = vararg:sub(i, -1) + end + end + + return values, varargs +end + + -- export module table return mqtt diff --git a/tests/spec/topics.lua b/tests/spec/topics.lua new file mode 100644 index 0000000..36070ce --- /dev/null +++ b/tests/spec/topics.lua @@ -0,0 +1,300 @@ +local mqtt = require "mqtt" + +describe("topics", function() + + describe("publish (plain)", function() + it("allows proper topics", function() + local ok, err + ok, err = mqtt.validate_publish_topic("hello/world") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_publish_topic("hello/world/") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_publish_topic("/hello/world") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_publish_topic("/") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_publish_topic("//////") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_publish_topic("/") + assert.is_nil(err) + assert.is.truthy(ok) + + end) + + it("returns the topic passed in on success", function() + local ok = mqtt.validate_publish_topic("hello/world") + assert.are.equal("hello/world", ok) + end) + + it("must be a string", function() + local ok, err = mqtt.validate_publish_topic(true) + assert.is_false(ok) + assert.is_string(err) + end) + + it("minimum length 1", function() + local ok, err = mqtt.validate_publish_topic("") + assert.is_false(ok) + assert.is_string(err) + end) + + it("wildcard '#' is not allowed", function() + local ok, err = mqtt.validate_publish_topic("hello/world/#") + assert.is_false(ok) + assert.is_string(err) + end) + + it("wildcard '+' is not allowed", function() + local ok, err = mqtt.validate_publish_topic("hello/+/world") + assert.is_false(ok) + assert.is_string(err) + end) + + end) + + + + describe("subscribe (wildcarded)", function() + + it("allows proper topics", function() + local ok, err + ok, err = mqtt.validate_subscribe_topic("hello/world") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("hello/world/") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("/hello/world") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("/") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("//////") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("#") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("/#") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("+") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("+/hello/#") + assert.is_nil(err) + assert.is.truthy(ok) + + ok, err = mqtt.validate_subscribe_topic("+/+/+/+/+") + assert.is_nil(err) + assert.is.truthy(ok) + end) + + it("returns the topic passed in on success", function() + local ok = mqtt.validate_subscribe_topic("hello/world") + assert.are.equal("hello/world", ok) + end) + + it("must be a string", function() + local ok, err = mqtt.validate_subscribe_topic(true) + assert.is_false(ok) + assert.is_string(err) + end) + + it("minimum length 1", function() + local ok, err = mqtt.validate_subscribe_topic("") + assert.is_false(ok) + assert.is_string(err) + end) + + it("wildcard '#' is only allowed as last segment", function() + local ok, err = mqtt.validate_subscribe_topic("hello/#/world") + assert.is_false(ok) + assert.is_string(err) + end) + + it("wildcard '+' is only allowed as full segment", function() + local ok, err = mqtt.validate_subscribe_topic("hello/+there/world") + assert.is_false(ok) + assert.is_string(err) + end) + + end) + + + + describe("pattern compiler & matcher", function() + + it("basic parsing works", function() + local opts = { + topic = "+/+", + pattern = nil, + keys = { "hello", "world"} + } + local res, err = mqtt.topic_match("hello/world", opts) + assert.is_nil(err) + assert.same(res, { + "hello", "world", + hello = "hello", + world = "world", + }) + -- compiled pattern is now added + assert.not_nil(opts.pattern) + end) + + it("incoming topic is required", function() + local opts = { + topic = "+/+", + pattern = nil, + keys = { "hello", "world"} + } + local ok, err = mqtt.topic_match(nil, opts) + assert.is_false(ok) + assert.is_string(err) + end) + + it("wildcard topic or pattern is required", function() + local opts = { + topic = nil, + pattern = nil, + keys = { "hello", "world"} + } + local ok, err = mqtt.topic_match("hello/world", opts) + assert.is_false(ok) + assert.is_string(err) + end) + + it("pattern must match", function() + local opts = { + topic = "+/+/+", -- one too many + pattern = nil, + keys = { "hello", "world"} + } + local ok, err = mqtt.topic_match("hello/world", opts) + assert.is_false(ok) + assert.is_string(err) + end) + + it("pattern '+' works", function() + local opts = { + topic = "+", + pattern = nil, + keys = { "hello" } + } + -- matches topic + local res, err = mqtt.topic_match("hello", opts) + assert.is_nil(err) + assert.same(res, { + "hello", + hello = "hello", + }) + end) + + it("wildcard '+' matches empty segments", function() + local opts = { + topic = "+/+/+", + pattern = nil, + keys = { "hello", "there", "world"} + } + local res, err = mqtt.topic_match("//", opts) + assert.is_nil(err) + assert.same(res, { + "", "", "", + hello = "", + there = "", + world = "", + }) + end) + + it("pattern '#' matches all segments", function() + local opts = { + topic = "#", + pattern = nil, + keys = nil, + } + local res, var = mqtt.topic_match("hello/there/world", opts) + assert.same(res, { + "hello/there/world" + }) + assert.same(var, { + "hello", + "there", + "world", + }) + end) + + it("pattern '/#' skips first segment", function() + local opts = { + topic = "/#", + pattern = nil, + keys = nil, + } + local res, var = mqtt.topic_match("/hello/world", opts) + assert.same(res, { + "hello/world" + }) + assert.same(var, { + "hello", + "world", + }) + end) + + it("combined wildcards '+/+/#'", function() + local opts = { + topic = "+/+/#", + pattern = nil, + keys = nil, + } + local res, var = mqtt.topic_match("hello/there/my/world", opts) + assert.same(res, { + "hello", + "there", + "my/world" + }) + assert.same(var, { + "my", + "world", + }) + end) + + it("trailing '/' in topic with '#'", function() + local opts = { + topic = "+/+/#", + pattern = nil, + keys = nil, + } + local res, var = mqtt.topic_match("hello/there/world/", opts) + assert.same(res, { + "hello", + "there", + "world/" + }) + assert.same(var, { + "world", + "", + }) + end) + + + end) + +end)