diff --git a/examples/request_reply.zig b/examples/request_reply.zig index 1f167cf..c69f1b6 100644 --- a/examples/request_reply.zig +++ b/examples/request_reply.zig @@ -29,10 +29,10 @@ pub fn main() !void { defer connection.destroy(); var count: u32 = 0; - const subscription = try connection.subscribe(u32, "channel", onMessage, &count); + const subscription = try connection.subscribe(*u32, "channel", onMessage, &count); defer subscription.destroy(); - while (count < 10) : (nats.sleep(1000)) { + while (count < 10) : (nats.sleep(100)) { const reply = try connection.request("channel", "greetings", 1000); defer reply.destroy(); diff --git a/src/connection.zig b/src/connection.zig index 312c29d..018ade6 100644 --- a/src/connection.zig +++ b/src/connection.zig @@ -18,24 +18,21 @@ pub const nats_c = @cImport({ @cInclude("nats/nats.h"); }); -const sub_ = @import("./subscription.zig"); -const Subscription = sub_.Subscription; -const SubscriptionCallbackSignature = sub_.SubscriptionCallbackSignature; -const makeSubscriptionCallbackThunk = sub_.makeSubscriptionCallbackThunk; +const Subscription = @import("./subscription.zig").Subscription; +const SubscriptionCallbackSignature = @import("./subscription.zig").SubscriptionCallbackSignature; +const makeSubscriptionCallbackThunk = @import("./subscription.zig").makeSubscriptionCallbackThunk; -const msg_ = @import("./message.zig"); -const Message = msg_.Message; +const Message = @import("./message.zig").Message; -const err_ = @import("./error.zig"); -const Error = err_.Error; -const Status = err_.Status; -const ErrorInfo = err_.ErrorInfo; +const Error = @import("./error.zig").Error; +const Status = @import("./error.zig").Status; +const ErrorInfo = @import("./error.zig").ErrorInfo; -const sta_ = @import("./statistics.zig"); -const Statistics = sta_.Statistics; -const StatsCounts = sta_.StatsCounts; +const Statistics = @import("./statistics.zig").Statistics; +const StatsCounts = @import("./statistics.zig").StatsCounts; const thunk = @import("./thunk.zig"); +const checkUserDataType = @import("./thunk.zig").checkUserDataType; pub const default_server_url: [:0]const u8 = nats_c.NATS_DEFAULT_URL; @@ -343,7 +340,7 @@ pub const Connection = opaque { comptime T: type, subject: [:0]const u8, callback: SubscriptionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!*Subscription { var sub: *Subscription = undefined; const status = Status.fromInt(nats_c.natsConnection_Subscribe( @@ -362,7 +359,7 @@ pub const Connection = opaque { subject: [:0]const u8, timeout: i64, callback: SubscriptionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!*Subscription { var sub: *Subscription = undefined; @@ -396,7 +393,7 @@ pub const Connection = opaque { subject: [:0]const u8, queue_group: [:0]const u8, callback: SubscriptionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!*Subscription { var sub: *Subscription = undefined; @@ -419,7 +416,7 @@ pub const Connection = opaque { queue_group: [:0]const u8, timeout: i64, callback: SubscriptionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!*Subscription { var sub: *Subscription = undefined; @@ -502,7 +499,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const TokenCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetTokenHandler( @ptrCast(self), @@ -641,7 +638,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ReconnectDelayCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt( nats_c.natsOptions_SetCustomReconnectDelay( @@ -668,7 +665,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ErrorHandlerCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt( nats_c.natsOptions_SetErrorHandler( @@ -683,7 +680,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), @@ -696,7 +693,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), @@ -709,7 +706,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), @@ -722,7 +719,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), @@ -735,7 +732,7 @@ pub const ConnectionOptions = opaque { self: *ConnectionOptions, comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetClosedCB( @ptrCast(self), @@ -752,7 +749,7 @@ pub const ConnectionOptions = opaque { comptime read_callback: *const AttachEventLoopCallbackSignature(T), comptime write_callback: *const AttachEventLoopCallbackSignature(T), comptime detach_callback: *const thunk.SimpleCallbackSignature(T), - loop: *L, + loop: L, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetEventLoop( @ptrCast(self), @@ -819,7 +816,7 @@ pub const ConnectionOptions = opaque { comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), retry: bool, - userdata: *T, + userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetRetryOnFailedConnect( @ptrCast(self), @@ -835,8 +832,8 @@ pub const ConnectionOptions = opaque { comptime U: type, comptime jwt_callback: *const JwtHandlerCallbackSignature(T), comptime sig_callback: *const SignatureHandlerCallbackSignature(U), - jwt_userdata: *T, - sig_userdata: *U, + jwt_userdata: T, + sig_userdata: U, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetUserCredentialsCallbacks( @ptrCast(self), @@ -871,7 +868,7 @@ pub const ConnectionOptions = opaque { comptime T: type, comptime sig_callback: *const SignatureHandlerCallbackSignature(T), pub_key: [:0]const u8, - sig_userdata: *T, + sig_userdata: T, ) Error!void { return Status.fromInt(nats_c.natsOptions_SetUserCredentialsCallbacks( @ptrCast(self), @@ -919,16 +916,17 @@ pub const ConnectionOptions = opaque { const TokenCallback = fn (?*anyopaque) callconv(.C) [*c]const u8; pub fn TokenCallbackSignature(comptime T: type) type { - return fn (*T) [:0]const u8; + return fn (T) [:0]const u8; } fn makeTokenCallbackThunk( comptime T: type, comptime callback: *const TokenCallbackSignature(T), ) *const TokenCallback { + comptime checkUserDataType(T); return struct { fn thunk(userdata: ?*anyopaque) callconv(.C) [*c]const u8 { - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; return callback(data).ptr; } }.thunk; @@ -937,17 +935,18 @@ fn makeTokenCallbackThunk( const ConnectionCallback = fn (?*nats_c.natsConnection, ?*anyopaque) callconv(.C) void; pub fn ConnectionCallbackSignature(comptime T: type) type { - return fn (*T, *Connection) void; + return fn (T, *Connection) void; } fn makeConnectionCallbackThunk( comptime T: type, comptime callback: *const ConnectionCallbackSignature(T), ) *const ConnectionCallback { + comptime checkUserDataType(T); return struct { fn thunk(conn: ?*nats_c.natsConnection, userdata: ?*anyopaque) callconv(.C) void { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; callback(data, connection); } }.thunk; @@ -956,13 +955,14 @@ fn makeConnectionCallbackThunk( const ReconnectDelayCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) i64; pub fn ReconnectDelayCallbackSignature(comptime T: type) type { - return fn (*T, *Connection, c_int) i64; + return fn (T, *Connection, c_int) i64; } fn makeReconnectDelayCallbackThunk( comptime T: type, comptime callback: *const ReconnectDelayCallbackSignature(T), ) *const ReconnectDelayCallback { + comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -970,7 +970,7 @@ fn makeReconnectDelayCallbackThunk( userdata: ?*anyopaque, ) callconv(.C) i64 { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; return callback(data, connection, attempts); } }.thunk; @@ -984,13 +984,14 @@ const ErrorHandlerCallback = fn ( ) callconv(.C) void; pub fn ErrorHandlerCallbackSignature(comptime T: type) type { - return fn (*T, *Connection, *Subscription, Status) void; + return fn (T, *Connection, *Subscription, Status) void; } fn makeErrorHandlerCallbackThunk( comptime T: type, comptime callback: *const ErrorHandlerCallbackSignature(T), ) *const ErrorHandlerCallback { + comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -1000,7 +1001,7 @@ fn makeErrorHandlerCallbackThunk( ) callconv(.C) void { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable; - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; callback(data, connection, subscription, Status.fromInt(status)); } @@ -1016,7 +1017,7 @@ const AttachEventLoopCallback = fn ( ) callconv(.C) nats_c.natsStatus; pub fn AttachEventLoopCallbackSignature(comptime T: type, comptime L: type) type { - return fn (*L, *Connection, c_int) anyerror!*T; + return fn (L, *Connection, c_int) anyerror!T; } fn makeAttachEventLoopCallbackThunk( @@ -1024,6 +1025,8 @@ fn makeAttachEventLoopCallbackThunk( comptime L: type, comptime callback: *const AttachEventLoopCallbackSignature(T, L), ) *const ReconnectDelayCallback { + comptime checkUserDataType(T); + comptime checkUserDataType(L); return struct { fn thunk( userdata: *?*anyopaque, @@ -1032,7 +1035,7 @@ fn makeAttachEventLoopCallbackThunk( sock: ?*nats_c.natsSock, ) callconv(.C) nats_c.natsStatus { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const ev_loop: *L = if (loop) |l| @ptrCast(l) else unreachable; + const ev_loop: L = if (loop) |l| @alignCast(@ptrCast(l)) else unreachable; userdata.* = callback(ev_loop, connection, sock) catch |err| return Status.fromError(err).toInt(); @@ -1045,13 +1048,14 @@ fn makeAttachEventLoopCallbackThunk( const EventLoopAddRemoveCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) nats_c.natsStatus; pub fn EventLoopAddRemoveCallbackSignature(comptime T: type) type { - return fn (*T, *Connection, c_int) anyerror!void; + return fn (T, *Connection, c_int) anyerror!void; } fn makeEventLoopAddRemoveCallbackThunk( comptime T: type, comptime callback: *const EventLoopAddRemoveCallbackSignature(T), ) *const ReconnectDelayCallback { + comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -1059,7 +1063,7 @@ fn makeEventLoopAddRemoveCallbackThunk( userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; callback(data, connection, attempts) catch |err| return Status.fromError(err).toInt(); @@ -1071,18 +1075,19 @@ fn makeEventLoopAddRemoveCallbackThunk( const EventLoopDetachCallback = fn (?*anyopaque) callconv(.C) nats_c.natsStatus; pub fn EventLoopDetachCallbackSignature(comptime T: type) type { - return fn (*T) anyerror!void; + return fn (T) anyerror!void; } fn makeEventLoopDetachCallbackThunk( comptime T: type, comptime callback: *const EventLoopDetachCallbackSignature(T), ) *const ReconnectDelayCallback { + comptime checkUserDataType(T); return struct { fn thunk( userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; callback(data) catch |err| return Status.fromError(err).toInt(); return nats_c.NATS_OK; } @@ -1099,20 +1104,21 @@ pub const JwtResponseOrError = union(enum) { }; pub fn JwtHandlerCallbackSignature(comptime T: type) type { - return fn (*T) JwtResponseOrError; + return fn (T) JwtResponseOrError; } fn makeJwtHandlerCallbackThunk( comptime T: type, comptime callback: *const JwtHandlerCallbackSignature(T), ) *const JwtHandlerCallback { + comptime checkUserDataType(T); return struct { fn thunk( jwt_out_raw: ?*?[*:0]u8, err_out_raw: ?*?[*:0]u8, userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const err_out = err_out_raw orelse unreachable; const jwt_out = jwt_out_raw orelse unreachable; @@ -1140,13 +1146,14 @@ pub const SignatureResponseOrError = union(enum) { }; pub fn SignatureHandlerCallbackSignature(comptime T: type) type { - return fn (*T, [:0]const u8) SignatureResponseOrError; + return fn (T, [:0]const u8) SignatureResponseOrError; } fn makeSignatureHandlerCallbackThunk( comptime T: type, comptime callback: *const SignatureHandlerCallbackSignature(T), ) *const SignatureHandlerCallback { + comptime checkUserDataType(T); return struct { fn thunk( err_out_raw: ?*?[*:0]u8, @@ -1155,7 +1162,7 @@ fn makeSignatureHandlerCallbackThunk( nonsense: ?[*:0]const u8, userdata: ?*anyopaque, ) callconv(.C) nats_c.natsStatus { - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const nonce = nonsense orelse unreachable; const err_out = err_out_raw orelse unreachable; const sig_out = sig_out_raw orelse unreachable; diff --git a/src/subscription.zig b/src/subscription.zig index 27c0bfb..f2bf350 100644 --- a/src/subscription.zig +++ b/src/subscription.zig @@ -27,6 +27,7 @@ const Error = err_.Error; const Status = err_.Status; const thunk = @import("./thunk.zig"); +const checkUserDataType = @import("./thunk.zig").checkUserDataType; pub const Subscription = opaque { pub const MessageCount = struct { @@ -191,13 +192,14 @@ const SubscriptionCallback = fn ( ) callconv(.C) void; pub fn SubscriptionCallbackSignature(comptime T: type) type { - return fn (*T, *Connection, *Subscription, *Message) void; + return fn (T, *Connection, *Subscription, *Message) void; } pub fn makeSubscriptionCallbackThunk( comptime T: type, comptime callback: *const SubscriptionCallbackSignature(T), ) *const SubscriptionCallback { + comptime checkUserDataType(T); return struct { fn thunk( conn: ?*nats_c.natsConnection, @@ -211,7 +213,7 @@ pub fn makeSubscriptionCallbackThunk( const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable; - const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; + const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; callback(data, connection, subscription, message); } diff --git a/src/thunk.zig b/src/thunk.zig index d430077..52ad9b8 100644 --- a/src/thunk.zig +++ b/src/thunk.zig @@ -18,16 +18,43 @@ pub const nats_c = @cImport({ @cInclude("nats/nats.h"); }); +pub fn checkUserDataType(comptime T: type) void { + switch (@typeInfo(T)) { + .Optional => |info| switch (@typeInfo(info.child)) { + .Optional => @compileError( + "nats callbacks can only accept an (optional) single, many," ++ + " or c pointer as userdata, not slices. \"" ++ + @typeName(T) ++ "\" has more than one optional specifier.", + ), + else => checkUserDataType(info.child), + }, + .Pointer => |info| switch (info.size) { + .Slice => @compileError( + "nats callbacks can only accept an (optional) single, many," ++ + " or c pointer as userdata, not slices. \"" ++ + @typeName(T) ++ "\" appears to be a slice.", + ), + else => {}, + }, + else => @compileError( + "nats callbacks can only accept an (optional) single, many," ++ + " or c pointer as userdata, not slices. \"" ++ + @typeName(T) ++ "\" is not a pointer type.", + ), + } +} + const SimpleCallback = fn (?*anyopaque) callconv(.C) void; pub fn SimpleCallbackThunkSignature(comptime T: type) type { - return fn (*T) void; + return fn (T) void; } pub fn makeSimpleCallbackThunk( comptime T: type, comptime callback: *const SimpleCallbackThunkSignature(T), ) *const SimpleCallback { + comptime checkUserDataType(T); return struct { fn thunk(userdata: ?*anyopaque) callconv(.C) void { const data: *T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; diff --git a/tests/connection.zig b/tests/connection.zig index f952475..1425948 100644 --- a/tests/connection.zig +++ b/tests/connection.zig @@ -164,14 +164,14 @@ test "nats.ConnectionOptions" { try options.setMaxReconnect(10); try options.setReconnectWait(500); try options.setReconnectJitter(100, 200); - try options.setCustomReconnectDelay(u32, reconnectDelayHandler, &userdata); + try options.setCustomReconnectDelay(*u32, reconnectDelayHandler, &userdata); try options.setReconnectBufSize(1024); try options.setMaxPendingMessages(50); - try options.setErrorHandler(u32, errorHandler, &userdata); - try options.setClosedCallback(u32, connectionHandler, &userdata); - try options.setDisconnectedCallback(u32, connectionHandler, &userdata); - try options.setDiscoveredServersCallback(u32, connectionHandler, &userdata); - try options.setLameDuckModeCallback(u32, connectionHandler, &userdata); + try options.setErrorHandler(*u32, errorHandler, &userdata); + try options.setClosedCallback(*u32, connectionHandler, &userdata); + try options.setDisconnectedCallback(*u32, connectionHandler, &userdata); + try options.setDiscoveredServersCallback(*u32, connectionHandler, &userdata); + try options.setLameDuckModeCallback(*u32, connectionHandler, &userdata); try options.ignoreDiscoveredServers(true); try options.useGlobalMessageDelivery(false); try options.ipResolutionOrder(.ipv4_first); @@ -179,8 +179,8 @@ test "nats.ConnectionOptions" { try options.useOldRequestStyle(false); try options.setFailRequestsOnDisconnect(true); try options.setNoEcho(true); - try options.setRetryOnFailedConnect(u32, connectionHandler, true, &userdata); - try options.setUserCredentialsCallbacks(u32, u32, jwtHandler, signatureHandler, &userdata, &userdata); + try options.setRetryOnFailedConnect(*u32, connectionHandler, true, &userdata); + try options.setUserCredentialsCallbacks(*u32, *u32, jwtHandler, signatureHandler, &userdata, &userdata); try options.setWriteDeadline(5); try options.disableNoResponders(true); try options.setCustomInboxPrefix("_FOOBOX"); @@ -200,7 +200,7 @@ test "nats.ConnectionOptions (crypto edition)" { defer options.destroy(); var userdata: u32 = 0; - try options.setTokenHandler(u32, tokenHandler, &userdata); + try options.setTokenHandler(*u32, tokenHandler, &userdata); try options.setSecure(false); try options.setCertificatesChain(rsa_cert, rsa_key); try options.setCiphers("-ALL:HIGH"); diff --git a/tests/subscription.zig b/tests/subscription.zig index c8e33b0..bd508c2 100644 --- a/tests/subscription.zig +++ b/tests/subscription.zig @@ -113,7 +113,7 @@ test "nats.Subscription (async)" { { var count: u32 = 0; - const subscription = try connection.subscribe(u32, message_subject, onMessage, &count); + const subscription = try connection.subscribe(*u32, message_subject, onMessage, &count); defer subscription.destroy(); const response = try connection.requestMessage(message, 1000); @@ -126,7 +126,7 @@ test "nats.Subscription (async)" { { var count: u32 = 0; const subscription = try connection.subscribeTimeout( - u32, + *u32, message_subject, 1000, onMessage, @@ -144,7 +144,7 @@ test "nats.Subscription (async)" { { var count: u32 = 0; const subscription = try connection.queueSubscribe( - u32, + *u32, message_subject, "queuegroup", onMessage, @@ -162,7 +162,7 @@ test "nats.Subscription (async)" { { var count: u32 = 0; const subscription = try connection.queueSubscribeTimeout( - u32, + *u32, message_subject, "queuegroup", 1000,