diff --git a/src/sqlite.zig b/src/sqlite.zig index 1099f48..9444752 100644 --- a/src/sqlite.zig +++ b/src/sqlite.zig @@ -1,4 +1,5 @@ -pub const std = @import("std"); +const builtin = @import("builtin"); +const std = @import("std"); const util = @import("util.zig"); const Value = @import("value.zig").Value; const Connection = @import("connection.zig").Connection; @@ -14,6 +15,7 @@ pub const SQLite3 = opaque { flags: c_int = c.SQLITE_OPEN_READWRITE | c.SQLITE_OPEN_CREATE | c.SQLITE_OPEN_FULLMUTEX | c.SQLITE_OPEN_EXRESCODE, busy_timeout: ?c_int = 5_000, foreign_keys: ?enum { off, on } = .on, + extensions: []const []const u8 = &.{}, }; pub fn open(options: Options) !*SQLite3 { @@ -31,9 +33,43 @@ pub const SQLite3 = opaque { inline else => |t| db.execAll("PRAGMA foreign_keys = " ++ @tagName(t)) catch return error.ConnectionFailed, }; + for (options.extensions) |ext| { + db.loadExtension(ext) catch return error.ConnectionFailed; + } + return db; } + pub fn loadExtension(self: *SQLite3, name: []const u8) !void { + var err: ?[*:0]const u8 = null; + var buf: [255]u8 = undefined; + const sym = try std.fmt.bufPrintZ(&buf, "sqlite3_{s}_init", .{name}); + + switch (builtin.target.os.tag) { + .macos, .linux => { + const lib = dlopen(null, 1) orelse return error.LoadExtensionFailed; + defer _ = std.c.dlclose(lib); + + if (std.c.dlsym(lib, sym)) |p| { + // NOTE: Statically linked extensions should be compiled with -DSQLITE_CORE, + // in which case the sqlite3_api_routines parameter is unused. + // see https://github.com/sqlite/sqlite/blob/eaa50b866075f4c1a19065600e4f1bae059eb505/src/sqlite3ext.h#L712 + const init: *const fn (*c.sqlite3, *?[*:0]const u8, *const c.sqlite3_api_routines) callconv(.C) c_int = @ptrCast(@alignCast(p)); + return check(init(self.ptr(), &err, undefined)); + } + }, + else => {}, // TODO: Windows + } + + if (comptime @hasField(c, "sqlite3_enable_load_extension")) { + const zName = try std.fmt.bufPrintZ(&buf, "{s}", .{name}); + try check(c.sqlite3_load_extension(self.ptr(), zName, null, null)); + } else { + util.log.err("SQLite extension loading is disabled", .{}); + return error.LoadExtensionFailed; + } + } + pub fn execAll(self: *SQLite3, sql: []const u8) !void { const csql = try std.heap.c_allocator.dupeZ(u8, sql); defer std.heap.c_allocator.free(csql); @@ -137,3 +173,6 @@ pub fn check(code: c_int) !void { else => error.DbError, }; } + +// Because std.c.dlopen is wrong. +extern "c" fn dlopen(path: ?[*:0]const u8, mode: c_int) ?*anyopaque;