diff --git a/src/connection.zig b/src/connection.zig index 2de3e6b..9fdc218 100644 --- a/src/connection.zig +++ b/src/connection.zig @@ -29,8 +29,7 @@ const ErrorInfo = @import("./error.zig").ErrorInfo; const Statistics = @import("./statistics.zig").Statistics; const StatsCounts = @import("./statistics.zig").StatsCounts; -const thunk = @import("./thunk.zig"); -const checkUserDataType = @import("./thunk.zig").checkUserDataType; +const thunkhelper = @import("./thunk.zig"); pub const default_server_url: [:0]const u8 = nats_c.NATS_DEFAULT_URL; @@ -346,7 +345,7 @@ pub const Connection = opaque { @ptrCast(self), subject.ptr, makeSubscriptionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )); return status.toError() orelse sub; } @@ -367,7 +366,7 @@ pub const Connection = opaque { subject.ptr, timeout, makeSubscriptionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )); return status.toError() orelse sub; @@ -401,7 +400,7 @@ pub const Connection = opaque { subject.ptr, queue_group.ptr, makeSubscriptionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )); return status.toError() orelse sub; @@ -425,7 +424,7 @@ pub const Connection = opaque { queue_group.ptr, timeout, makeSubscriptionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )); return status.toError() orelse sub; @@ -502,7 +501,7 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetTokenHandler( @ptrCast(self), makeTokenCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -642,7 +641,7 @@ pub const ConnectionOptions = opaque { nats_c.natsOptions_SetCustomReconnectDelay( @ptrCast(self), makeReconnectDelayCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), ), ).raise(); } @@ -669,7 +668,7 @@ pub const ConnectionOptions = opaque { nats_c.natsOptions_SetErrorHandler( @ptrCast(self), makeErrorHandlerCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), ), ).raise(); } @@ -683,7 +682,7 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), makeConnectionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -696,7 +695,7 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), makeConnectionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -709,7 +708,7 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), makeConnectionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -722,7 +721,7 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), makeConnectionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -735,7 +734,7 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), makeConnectionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -746,12 +745,12 @@ pub const ConnectionOptions = opaque { comptime attach_callback: *const AttachEventLoopCallbackSignature(T, L), comptime read_callback: *const AttachEventLoopCallbackSignature(T), comptime write_callback: *const AttachEventLoopCallbackSignature(T), - comptime detach_callback: *const thunk.SimpleCallbackSignature(T), + comptime detach_callback: *const thunkhelper.SimpleCallbackSignature(T), loop: L, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetEventLoop( @ptrCast(self), - @constCast(@ptrCast(loop)), + thunkhelper.opaqueFromUserdata(loop), makeAttachEventLoopCallbackThunk(T, L, attach_callback), makeEventLoopAddRemoveCallbackThunk(T, read_callback), makeEventLoopAddRemoveCallbackThunk(T, write_callback), @@ -820,7 +819,7 @@ pub const ConnectionOptions = opaque { @ptrCast(self), retry, makeConnectionCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } @@ -836,9 +835,9 @@ pub const ConnectionOptions = opaque { return Status.fromInt(nats_c.natsOptions_SetUserCredentialsCallbacks( @ptrCast(self), makeJwtHandlerCallbackThunk(T, jwt_callback), - @constCast(@ptrCast(jwt_userdata)), + thunkhelper.opaqueFromUserdata(jwt_userdata), makeSignatureHandlerCallbackThunk(U, sig_callback), - @constCast(@ptrCast(sig_userdata)), + thunkhelper.opaqueFromUserdata(sig_userdata), )).raise(); } @@ -876,7 +875,7 @@ pub const ConnectionOptions = opaque { @ptrCast(self), pub_key.ptr, makeSignatureHandlerCallbackThunk(T, sig_callback), - @constCast(@ptrCast(sig_userdata)), + thunkhelper.opaqueFromUserdata(sig_userdata), )).raise(); } @@ -917,18 +916,17 @@ pub const ConnectionOptions = opaque { const TokenCallback = fn (?*anyopaque) callconv(.C) [*c]const u8; -pub fn TokenCallbackSignature(comptime T: type) type { - return fn (T) [:0]const u8; +pub fn TokenCallbackSignature(comptime UDT: type) type { + return fn (UDT) [:0]const u8; } fn makeTokenCallbackThunk( - comptime T: type, - comptime callback: *const TokenCallbackSignature(T), + comptime UDT: type, + comptime callback: *const TokenCallbackSignature(UDT), ) *const TokenCallback { - comptime checkUserDataType(T); return struct { fn thunk(userdata: ?*anyopaque) callconv(.C) [*c]const u8 { - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); return callback(data).ptr; } }.thunk; @@ -936,19 +934,18 @@ fn makeTokenCallbackThunk( const ConnectionCallback = fn (?*nats_c.natsConnection, ?*anyopaque) callconv(.C) void; -pub fn ConnectionCallbackSignature(comptime T: type) type { - return fn (T, *Connection) void; +pub fn ConnectionCallbackSignature(comptime UDT: type) type { + return fn (UDT, *Connection) void; } fn makeConnectionCallbackThunk( - comptime T: type, - comptime callback: *const ConnectionCallbackSignature(T), + comptime UDT: type, + comptime callback: *const ConnectionCallbackSignature(UDT), ) *const ConnectionCallback { - comptime checkUserDataType(T); return struct { fn thunk(conn: ?*nats_c.natsConnection, userdata: ?*anyopaque) callconv(.C) void { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); callback(data, connection); } }.thunk; @@ -956,15 +953,14 @@ fn makeConnectionCallbackThunk( const ReconnectDelayCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) i64; -pub fn ReconnectDelayCallbackSignature(comptime T: type) type { - return fn (T, *Connection, c_int) i64; +pub fn ReconnectDelayCallbackSignature(comptime UDT: type) type { + return fn (UDT, *Connection, c_int) i64; } fn makeReconnectDelayCallbackThunk( - comptime T: type, - comptime callback: *const ReconnectDelayCallbackSignature(T), + comptime UDT: type, + comptime callback: *const ReconnectDelayCallbackSignature(UDT), ) *const ReconnectDelayCallback { - comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -972,7 +968,7 @@ fn makeReconnectDelayCallbackThunk( userdata: ?*anyopaque, ) callconv(.C) i64 { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); return callback(data, connection, attempts); } }.thunk; @@ -985,15 +981,14 @@ const ErrorHandlerCallback = fn ( ?*anyopaque, ) callconv(.C) void; -pub fn ErrorHandlerCallbackSignature(comptime T: type) type { - return fn (T, *Connection, *Subscription, Status) void; +pub fn ErrorHandlerCallbackSignature(comptime UDT: type) type { + return fn (UDT, *Connection, *Subscription, Status) void; } fn makeErrorHandlerCallbackThunk( - comptime T: type, - comptime callback: *const ErrorHandlerCallbackSignature(T), + comptime UDT: type, + comptime callback: *const ErrorHandlerCallbackSignature(UDT), ) *const ErrorHandlerCallback { - comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -1003,8 +998,8 @@ fn makeErrorHandlerCallbackThunk( ) callconv(.C) void { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable; - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); callback(data, connection, subscription, Status.fromInt(status)); } }.thunk; @@ -1018,17 +1013,16 @@ const AttachEventLoopCallback = fn ( nats_c.natsSock, ) callconv(.C) nats_c.natsStatus; -pub fn AttachEventLoopCallbackSignature(comptime T: type, comptime L: type) type { - return fn (L, *Connection, c_int) anyerror!T; +pub fn AttachEventLoopCallbackSignature(comptime UDT: type, comptime L: type) type { + return fn (L, *Connection, c_int) anyerror!UDT; } fn makeAttachEventLoopCallbackThunk( - comptime T: type, + comptime UDT: type, comptime L: type, - comptime callback: *const AttachEventLoopCallbackSignature(T, L), + comptime callback: *const AttachEventLoopCallbackSignature(UDT, L), ) *const ReconnectDelayCallback { - comptime checkUserDataType(T); - comptime checkUserDataType(L); + comptime thunkhelper.checkUserDataType(L); return struct { fn thunk( userdata: *?*anyopaque, @@ -1037,10 +1031,12 @@ fn makeAttachEventLoopCallbackThunk( sock: ?*nats_c.natsSock, ) callconv(.C) nats_c.natsStatus { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const ev_loop: L = if (loop) |l| @alignCast(@ptrCast(l)) else unreachable; - userdata.* = callback(ev_loop, connection, sock) catch |err| + const ev_loop = thunkhelper.userdataFromOpaque(L, loop); + + const result = callback(ev_loop, connection, sock) catch |err| return Status.fromError(err).toInt(); + userdata.* = thunkhelper.opaqueFromUserdata(result); return nats_c.NATS_OK; } @@ -1049,15 +1045,14 @@ fn makeAttachEventLoopCallbackThunk( const EventLoopAddRemoveCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) nats_c.natsStatus; -pub fn EventLoopAddRemoveCallbackSignature(comptime T: type) type { - return fn (T, *Connection, c_int) anyerror!void; +pub fn EventLoopAddRemoveCallbackSignature(comptime UDT: type) type { + return fn (UDT, *Connection, c_int) anyerror!void; } fn makeEventLoopAddRemoveCallbackThunk( - comptime T: type, - comptime callback: *const EventLoopAddRemoveCallbackSignature(T), + comptime UDT: type, + comptime callback: *const EventLoopAddRemoveCallbackSignature(UDT), ) *const ReconnectDelayCallback { - comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -1065,7 +1060,7 @@ fn makeEventLoopAddRemoveCallbackThunk( userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); callback(data, connection, attempts) catch |err| return Status.fromError(err).toInt(); @@ -1076,21 +1071,21 @@ fn makeEventLoopAddRemoveCallbackThunk( const EventLoopDetachCallback = fn (?*anyopaque) callconv(.C) nats_c.natsStatus; -pub fn EventLoopDetachCallbackSignature(comptime T: type) type { - return fn (T) anyerror!void; +pub fn EventLoopDetachCallbackSignature(comptime UDT: type) type { + return fn (UDT) anyerror!void; } fn makeEventLoopDetachCallbackThunk( - comptime T: type, - comptime callback: *const EventLoopDetachCallbackSignature(T), + comptime UDT: type, + comptime callback: *const EventLoopDetachCallbackSignature(UDT), ) *const ReconnectDelayCallback { - comptime checkUserDataType(T); return struct { fn thunk( userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); callback(data) catch |err| return Status.fromError(err).toInt(); + return nats_c.NATS_OK; } }.thunk; @@ -1105,26 +1100,24 @@ pub const JwtResponseOrError = union(enum) { error_message: [:0]u8, }; -pub fn JwtHandlerCallbackSignature(comptime T: type) type { - return fn (T) JwtResponseOrError; +pub fn JwtHandlerCallbackSignature(comptime UDT: type) type { + return fn (UDT) JwtResponseOrError; } fn makeJwtHandlerCallbackThunk( - comptime T: type, - comptime callback: *const JwtHandlerCallbackSignature(T), + comptime UDT: type, + comptime callback: *const JwtHandlerCallbackSignature(UDT), ) *const JwtHandlerCallback { - comptime checkUserDataType(T); return struct { fn thunk( jwt_out_raw: ?*?[*:0]u8, err_out_raw: ?*?[*:0]u8, userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const err_out = err_out_raw orelse unreachable; const jwt_out = jwt_out_raw orelse unreachable; - switch (callback(data)) { + switch (callback(thunkhelper.userdataFromOpaque(UDT, userdata))) { .jwt => |jwt| { jwt_out.* = jwt.ptr; return nats_c.NATS_OK; @@ -1147,15 +1140,14 @@ pub const SignatureResponseOrError = union(enum) { error_message: [:0]u8, }; -pub fn SignatureHandlerCallbackSignature(comptime T: type) type { - return fn (T, [:0]const u8) SignatureResponseOrError; +pub fn SignatureHandlerCallbackSignature(comptime UDT: type) type { + return fn (UDT, [:0]const u8) SignatureResponseOrError; } fn makeSignatureHandlerCallbackThunk( - comptime T: type, - comptime callback: *const SignatureHandlerCallbackSignature(T), + comptime UDT: type, + comptime callback: *const SignatureHandlerCallbackSignature(UDT), ) *const SignatureHandlerCallback { - comptime checkUserDataType(T); return struct { fn thunk( err_out_raw: ?*?[*:0]u8, @@ -1164,12 +1156,12 @@ fn makeSignatureHandlerCallbackThunk( nonsense: ?[*:0]const u8, userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const nonce = nonsense orelse unreachable; const err_out = err_out_raw orelse unreachable; const sig_out = sig_out_raw orelse unreachable; const sig_len_out = sig_len_out_raw orelse unreachable; + const data = thunkhelper.userdataFromOpaque(UDT, userdata); switch (callback(data, std.mem.sliceTo(nonce, 0))) { .signature => |sig| { sig_out.* = sig.ptr; diff --git a/src/subscription.zig b/src/subscription.zig index f430f2a..546c9f4 100644 --- a/src/subscription.zig +++ b/src/subscription.zig @@ -24,8 +24,7 @@ const err_ = @import("./error.zig"); const Error = err_.Error; const Status = err_.Status; -const thunk = @import("./thunk.zig"); -const checkUserDataType = @import("./thunk.zig").checkUserDataType; +const thunkhelper = @import("./thunk.zig"); pub const Subscription = opaque { pub const MessageCount = struct { @@ -171,13 +170,13 @@ pub const Subscription = opaque { pub fn setCompletionCallback( self: *Subscription, comptime T: type, - comptime callback: *const thunk.SimpleCallbackThunkSignature(T), + comptime callback: *const thunkhelper.SimpleCallbackThunkSignature(T), userdata: T, ) Error!void { return Status.fromInt(nats_c.natsSubscription_SetOnCompleteCB( @ptrCast(self), - thunk.makeSimpleCallbackThunk(T, callback), - @constCast(@ptrCast(userdata)), + thunkhelper.makeSimpleCallbackThunk(T, callback), + thunkhelper.opaqueFromUserdata(userdata), )).raise(); } }; @@ -189,15 +188,14 @@ const SubscriptionCallback = fn ( ?*anyopaque, ) callconv(.C) void; -pub fn SubscriptionCallbackSignature(comptime T: type) type { - return fn (T, *Connection, *Subscription, *Message) void; +pub fn SubscriptionCallbackSignature(comptime UDT: type) type { + return fn (UDT, *Connection, *Subscription, *Message) void; } pub fn makeSubscriptionCallbackThunk( - comptime T: type, - comptime callback: *const SubscriptionCallbackSignature(T), + comptime UDT: type, + comptime callback: *const SubscriptionCallbackSignature(UDT), ) *const SubscriptionCallback { - comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -211,8 +209,7 @@ pub fn makeSubscriptionCallbackThunk( const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable; - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; - + const data = thunkhelper.userdataFromOpaque(UDT, userdata); callback(data, connection, subscription, message); } }.thunk; diff --git a/src/thunk.zig b/src/thunk.zig index c114213..c8972cd 100644 --- a/src/thunk.zig +++ b/src/thunk.zig @@ -16,68 +16,51 @@ const std = @import("std"); const nats_c = @import("./nats_c.zig").nats_c; -pub fn CallbackType(comptime T: type) type { - return switch (@typeInfo(T)) { - .optional => |info| ?CallbackType(info.child), - .pointer => |info| switch (info.size) { - .Slice => *const T, - else => T, - }, - else => *T, +const optional = if (@hasField(std.builtin.Type, "optional")) .optional else .Optional; +const pointer = if (@hasField(std.builtin.Type, "pointer")) .pointer else .Pointer; +const void_type = if (@hasField(std.builtin.Type, "void")) .void else .Void; +const null_type = if (@hasField(std.builtin.Type, "null")) .null else .Null; + +pub fn opaqueFromUserdata(userdata: anytype) ?*anyopaque { + checkUserDataType(@TypeOf(userdata)); + return switch (@typeInfo(@TypeOf(userdata))) { + optional, pointer => @constCast(@ptrCast(userdata)), + void_type => null, + else => @compileError("Unsupported userdata type " ++ @typeName(@TypeOf(userdata))), }; } -pub const checkUserDataType = if (@hasField(std.builtin.Type, "optional")) - checkUserDataType_14 -else - checkUserDataType_13; - -pub fn checkUserDataType_14(comptime T: type) void { - switch (@typeInfo(T)) { - .optional => |info| switch (@typeInfo(info.child)) { - .optional => @compileError( - "nats callbacks can only accept an (optional) single, many," ++ - " or c pointer as userdata. \"" ++ - @typeName(T) ++ "\" has more than one optional specifier.", - ), - else => checkUserDataType(info.child), - }, - .pointer => |info| switch (info.size) { - .Slice => @compileError( - "nats callbacks can only accept an (optional) single, many," ++ - " or c pointer as userdata, not slices. \"" ++ - @typeName(T) ++ "\" appears to be a slice.", - ), - else => {}, - }, - else => @compileError( - "nats callbacks can only accept an (optional) single, many," ++ - " or c pointer as userdata. \"" ++ - @typeName(T) ++ "\" is not a pointer type.", - ), - } +pub fn userdataFromOpaque(comptime UDT: type, userdata: ?*anyopaque) UDT { + comptime checkUserDataType(UDT); + return if (UDT == void) + void{} + else if (@typeInfo(UDT) == optional) + @alignCast(@ptrCast(userdata)) + else + @alignCast(@ptrCast(userdata.?)); } -pub fn checkUserDataType_13(comptime T: type) void { +pub fn checkUserDataType(comptime T: type) void { switch (@typeInfo(T)) { - .Optional => |info| switch (@typeInfo(info.child)) { - .Optional => @compileError( - "nats callbacks can only accept an (optional) single, many," ++ + optional => |info| switch (@typeInfo(info.child)) { + optional => @compileError( + "nats callbacks can only accept void or an (optional) single, many," ++ " or c pointer as userdata. \"" ++ @typeName(T) ++ "\" has more than one optional specifier.", ), else => checkUserDataType(info.child), }, - .Pointer => |info| switch (info.size) { + pointer => |info| switch (info.size) { .Slice => @compileError( - "nats callbacks can only accept an (optional) single, many," ++ + "nats callbacks can only accept void or an (optional) single, many," ++ " or c pointer as userdata, not slices. \"" ++ @typeName(T) ++ "\" appears to be a slice.", ), else => {}, }, + void_type => {}, else => @compileError( - "nats callbacks can only accept an (optional) single, many," ++ + "nats callbacks can only accept void or an (optional) single, many," ++ " or c pointer as userdata. \"" ++ @typeName(T) ++ "\" is not a pointer type.", ), @@ -86,19 +69,18 @@ pub fn checkUserDataType_13(comptime T: type) void { const SimpleCallback = fn (?*anyopaque) callconv(.C) void; -pub fn SimpleCallbackThunkSignature(comptime T: type) type { - return fn (T) void; +pub fn SimpleCallbackThunkSignature(comptime UDT: type) type { + return fn (UDT) void; } pub fn makeSimpleCallbackThunk( - comptime T: type, - comptime callback: *const SimpleCallbackThunkSignature(T), + comptime UDT: type, + comptime callback: *const SimpleCallbackThunkSignature(UDT), ) *const SimpleCallback { - comptime checkUserDataType(T); + comptime checkUserDataType(UDT); return struct { fn thunk(userdata: ?*anyopaque) callconv(.C) void { - const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; - callback(data); + callback(userdataFromOpaque(UDT, userdata)); } }.thunk; } diff --git a/tests/connection.zig b/tests/connection.zig index d571ba4..32fd48c 100644 --- a/tests/connection.zig +++ b/tests/connection.zig @@ -99,42 +99,46 @@ test "nats.Connection" { connection.drainTimeout(1000) catch {}; } -fn reconnectDelayHandler(userdata: *const u32, connection: *nats.Connection, attempts: c_int) i64 { - _ = userdata; - _ = connection; - _ = attempts; +fn callbacks(comptime UDT: type) type { + return struct { + fn reconnectDelayHandler(userdata: UDT, connection: *nats.Connection, attempts: c_int) i64 { + _ = userdata; + _ = connection; + _ = attempts; - return 0; -} + return 0; + } -fn errorHandler( - userdata: *const u32, - connection: *nats.Connection, - subscription: *nats.Subscription, - status: nats.Status, -) void { - _ = userdata; - _ = connection; - _ = subscription; - _ = status; -} + fn errorHandler( + userdata: UDT, + connection: *nats.Connection, + subscription: *nats.Subscription, + status: nats.Status, + ) void { + _ = userdata; + _ = connection; + _ = subscription; + _ = status; + } -fn connectionHandler(userdata: *const u32, connection: *nats.Connection) void { - _ = userdata; - _ = connection; -} + fn connectionHandler(userdata: UDT, connection: *nats.Connection) void { + _ = userdata; + _ = connection; + } -fn jwtHandler(userdata: *const u32) nats.JwtResponseOrError { - _ = userdata; - // return .{ .jwt = std.heap.raw_c_allocator.dupeZ(u8, "abcdef") catch @panic("no!") }; - return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "dang") catch @panic("no!") }; -} + fn jwtHandler(userdata: UDT) nats.JwtResponseOrError { + _ = userdata; + // return .{ .jwt = std.heap.raw_c_allocator.dupeZ(u8, "abcdef") catch @panic("no!") }; + return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "dang") catch @panic("no!") }; + } -fn signatureHandler(userdata: *const u32, nonce: [:0]const u8) nats.SignatureResponseOrError { - _ = userdata; - _ = nonce; - // return .{ .signature = std.heap.raw_c_allocator.dupe(u8, "01230123") catch @panic("no!") }; - return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "whoops") catch @panic("no!") }; + fn signatureHandler(userdata: UDT, nonce: [:0]const u8) nats.SignatureResponseOrError { + _ = userdata; + _ = nonce; + // return .{ .signature = std.heap.raw_c_allocator.dupe(u8, "01230123") catch @panic("no!") }; + return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "whoops") catch @panic("no!") }; + } + }; } test "nats.ConnectionOptions" { @@ -164,14 +168,26 @@ test "nats.ConnectionOptions" { try options.setMaxReconnect(10); try options.setReconnectWait(500); try options.setReconnectJitter(100, 200); - try options.setCustomReconnectDelay(*const u32, reconnectDelayHandler, &userdata); + try options.setCustomReconnectDelay(*const u32, callbacks(*const u32).reconnectDelayHandler, &userdata); + try options.setCustomReconnectDelay(void, callbacks(void).reconnectDelayHandler, {}); + try options.setCustomReconnectDelay(?*const u32, callbacks(?*const u32).reconnectDelayHandler, null); try options.setReconnectBufSize(1024); try options.setMaxPendingMessages(50); - try options.setErrorHandler(*const u32, errorHandler, &userdata); - try options.setClosedCallback(*const u32, connectionHandler, &userdata); - try options.setDisconnectedCallback(*const u32, connectionHandler, &userdata); - try options.setDiscoveredServersCallback(*const u32, connectionHandler, &userdata); - try options.setLameDuckModeCallback(*const u32, connectionHandler, &userdata); + try options.setErrorHandler(*const u32, callbacks(*const u32).errorHandler, &userdata); + try options.setErrorHandler(void, callbacks(void).errorHandler, {}); + try options.setErrorHandler(?*const u32, callbacks(?*const u32).errorHandler, null); + try options.setClosedCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata); + try options.setClosedCallback(void, callbacks(void).connectionHandler, {}); + try options.setClosedCallback(?*const u32, callbacks(?*const u32).connectionHandler, null); + try options.setDisconnectedCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata); + try options.setDisconnectedCallback(void, callbacks(void).connectionHandler, {}); + try options.setDisconnectedCallback(?*const u32, callbacks(?*const u32).connectionHandler, null); + try options.setDiscoveredServersCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata); + try options.setDiscoveredServersCallback(void, callbacks(void).connectionHandler, {}); + try options.setDiscoveredServersCallback(?*const u32, callbacks(?*const u32).connectionHandler, null); + try options.setLameDuckModeCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata); + try options.setLameDuckModeCallback(void, callbacks(void).connectionHandler, {}); + try options.setLameDuckModeCallback(?*const u32, callbacks(?*const u32).connectionHandler, null); try options.ignoreDiscoveredServers(true); try options.useGlobalMessageDelivery(false); try options.ipResolutionOrder(.ipv4_first); @@ -179,8 +195,11 @@ test "nats.ConnectionOptions" { try options.useOldRequestStyle(false); try options.setFailRequestsOnDisconnect(true); try options.setNoEcho(true); - try options.setRetryOnFailedConnect(*const u32, connectionHandler, true, &userdata); - try options.setUserCredentialsCallbacks(*const u32, *const u32, jwtHandler, signatureHandler, &userdata, &userdata); + try options.setRetryOnFailedConnect(*const u32, callbacks(*const u32).connectionHandler, true, &userdata); + try options.setRetryOnFailedConnect(void, callbacks(void).connectionHandler, true, {}); + try options.setRetryOnFailedConnect(?*const u32, callbacks(?*const u32).connectionHandler, true, null); + try options.setUserCredentialsCallbacks(*const u32, *const u32, callbacks(*const u32).jwtHandler, callbacks(*const u32).signatureHandler, &userdata, &userdata); + try options.setUserCredentialsCallbacks(void, void, callbacks(void).jwtHandler, callbacks(void).signatureHandler, {}, {}); try options.setWriteDeadline(5); try options.disableNoResponders(true); try options.setCustomInboxPrefix("_FOOBOX"); diff --git a/tests/main.zig b/tests/main.zig index 2cf0fb6..75513c3 100644 --- a/tests/main.zig +++ b/tests/main.zig @@ -1,9 +1,11 @@ // This file is licensed under the CC0 1.0 license. // See: https://creativecommons.org/publicdomain/zero/1.0/legalcode -test { - _ = @import("./nats.zig"); - _ = @import("./connection.zig"); - _ = @import("./message.zig"); - _ = @import("./subscription.zig"); +comptime { + if (@import("builtin").is_test) { + _ = @import("./nats.zig"); + _ = @import("./connection.zig"); + _ = @import("./message.zig"); + _ = @import("./subscription.zig"); + } } diff --git a/tests/subscription.zig b/tests/subscription.zig index b2feab8..6b6bd4e 100644 --- a/tests/subscription.zig +++ b/tests/subscription.zig @@ -190,3 +190,109 @@ test "nats.Subscription (async)" { ); } } + +fn onVoidMessage( + userdata: void, + connection: *nats.Connection, + subscription: *nats.Subscription, + message: *nats.Message, +) void { + _ = subscription; + _ = userdata; + + if (message.getReply()) |reply| { + connection.publish(reply, "greetings") catch @panic("OH NO"); + } else @panic("HOW"); +} + +fn onVoidClose(userdata: void) void { + _ = userdata; +} + +test "nats.Subscription (async, void)" { + var server = try util.TestServer.launch(.{}); + defer server.stop(); + + try nats.init(nats.default_spin_count); + defer nats.deinit(); + + const connection = try nats.Connection.connectTo(server.url); + defer connection.destroy(); + + const message_subject: [:0]const u8 = "hello"; + const message_reply: [:0]const u8 = "reply"; + const message_data: [:0]const u8 = "world"; + + const message = try nats.Message.create(message_subject, message_reply, message_data); + defer message.destroy(); + + { + { + const subscription = try connection.subscribe(void, message_subject, onVoidMessage, {}); + defer subscription.destroy(); + + try subscription.setCompletionCallback(void, onVoidClose, {}); + + const response = try connection.requestMessage(message, 1000); + try std.testing.expectEqualStrings( + "greetings", + response.getData() orelse return error.TestUnexpectedResult, + ); + } + // we have to sleep to allow the close callback to run. + nats.sleep(1); + } +} + +fn onNullMessage( + userdata: ?*void, + connection: *nats.Connection, + subscription: *nats.Subscription, + message: *nats.Message, +) void { + _ = subscription; + _ = userdata; + + if (message.getReply()) |reply| { + connection.publish(reply, "greetings") catch @panic("OH NO"); + } else @panic("HOW"); +} + +fn onNullClose(userdata: ?*void) void { + _ = userdata; +} + +test "nats.Subscription (async, null)" { + var server = try util.TestServer.launch(.{}); + defer server.stop(); + + try nats.init(nats.default_spin_count); + defer nats.deinit(); + + const connection = try nats.Connection.connectTo(server.url); + defer connection.destroy(); + + const message_subject: [:0]const u8 = "hello"; + const message_reply: [:0]const u8 = "reply"; + const message_data: [:0]const u8 = "world"; + + const message = try nats.Message.create(message_subject, message_reply, message_data); + defer message.destroy(); + + { + { + const subscription = try connection.subscribe(?*void, message_subject, onNullMessage, null); + defer subscription.destroy(); + + try subscription.setCompletionCallback(?*void, onNullClose, null); + + const response = try connection.requestMessage(message, 1000); + try std.testing.expectEqualStrings( + "greetings", + response.getData() orelse return error.TestUnexpectedResult, + ); + } + // we have to sleep to allow the close callback to run. + nats.sleep(1); + } +}