thunk: rework handling of userdata

A number of cases were not handled properly: even though optional
userdata was allowed, the handling would cause unreachable to be
reached if null handles were actually passed. Also, being able to use
void to specify "no userdata please" is useful; however, in this case
we do have to still pass NULL to the C library calls.

The casting logic has been pulled out to some helper functions, which
make it more consistent. Some mediocre additional test coverage has
been added as well.
This commit is contained in:
torque 2024-10-01 23:17:20 -07:00
parent 68a98232a1
commit c82911682e
Signed by: torque
SSH Key Fingerprint: SHA256:nCrXefBNo6EbjNSQhv0nXmEg/VuNq3sMF5b8zETw3Tk
6 changed files with 282 additions and 184 deletions

View File

@ -29,8 +29,7 @@ const ErrorInfo = @import("./error.zig").ErrorInfo;
const Statistics = @import("./statistics.zig").Statistics; const Statistics = @import("./statistics.zig").Statistics;
const StatsCounts = @import("./statistics.zig").StatsCounts; const StatsCounts = @import("./statistics.zig").StatsCounts;
const thunk = @import("./thunk.zig"); const thunkhelper = @import("./thunk.zig");
const checkUserDataType = @import("./thunk.zig").checkUserDataType;
pub const default_server_url: [:0]const u8 = nats_c.NATS_DEFAULT_URL; pub const default_server_url: [:0]const u8 = nats_c.NATS_DEFAULT_URL;
@ -346,7 +345,7 @@ pub const Connection = opaque {
@ptrCast(self), @ptrCast(self),
subject.ptr, subject.ptr,
makeSubscriptionCallbackThunk(T, callback), makeSubscriptionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)); ));
return status.toError() orelse sub; return status.toError() orelse sub;
} }
@ -367,7 +366,7 @@ pub const Connection = opaque {
subject.ptr, subject.ptr,
timeout, timeout,
makeSubscriptionCallbackThunk(T, callback), makeSubscriptionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)); ));
return status.toError() orelse sub; return status.toError() orelse sub;
@ -401,7 +400,7 @@ pub const Connection = opaque {
subject.ptr, subject.ptr,
queue_group.ptr, queue_group.ptr,
makeSubscriptionCallbackThunk(T, callback), makeSubscriptionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)); ));
return status.toError() orelse sub; return status.toError() orelse sub;
@ -425,7 +424,7 @@ pub const Connection = opaque {
queue_group.ptr, queue_group.ptr,
timeout, timeout,
makeSubscriptionCallbackThunk(T, callback), makeSubscriptionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)); ));
return status.toError() orelse sub; return status.toError() orelse sub;
@ -502,7 +501,7 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetTokenHandler( return Status.fromInt(nats_c.natsOptions_SetTokenHandler(
@ptrCast(self), @ptrCast(self),
makeTokenCallbackThunk(T, callback), makeTokenCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -642,7 +641,7 @@ pub const ConnectionOptions = opaque {
nats_c.natsOptions_SetCustomReconnectDelay( nats_c.natsOptions_SetCustomReconnectDelay(
@ptrCast(self), @ptrCast(self),
makeReconnectDelayCallbackThunk(T, callback), makeReconnectDelayCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
), ),
).raise(); ).raise();
} }
@ -669,7 +668,7 @@ pub const ConnectionOptions = opaque {
nats_c.natsOptions_SetErrorHandler( nats_c.natsOptions_SetErrorHandler(
@ptrCast(self), @ptrCast(self),
makeErrorHandlerCallbackThunk(T, callback), makeErrorHandlerCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
), ),
).raise(); ).raise();
} }
@ -683,7 +682,7 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetClosedCB( return Status.fromInt(nats_c.natsOptions_SetClosedCB(
@ptrCast(self), @ptrCast(self),
makeConnectionCallbackThunk(T, callback), makeConnectionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -696,7 +695,7 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetClosedCB( return Status.fromInt(nats_c.natsOptions_SetClosedCB(
@ptrCast(self), @ptrCast(self),
makeConnectionCallbackThunk(T, callback), makeConnectionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -709,7 +708,7 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetClosedCB( return Status.fromInt(nats_c.natsOptions_SetClosedCB(
@ptrCast(self), @ptrCast(self),
makeConnectionCallbackThunk(T, callback), makeConnectionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -722,7 +721,7 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetClosedCB( return Status.fromInt(nats_c.natsOptions_SetClosedCB(
@ptrCast(self), @ptrCast(self),
makeConnectionCallbackThunk(T, callback), makeConnectionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -735,7 +734,7 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetClosedCB( return Status.fromInt(nats_c.natsOptions_SetClosedCB(
@ptrCast(self), @ptrCast(self),
makeConnectionCallbackThunk(T, callback), makeConnectionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -746,12 +745,12 @@ pub const ConnectionOptions = opaque {
comptime attach_callback: *const AttachEventLoopCallbackSignature(T, L), comptime attach_callback: *const AttachEventLoopCallbackSignature(T, L),
comptime read_callback: *const AttachEventLoopCallbackSignature(T), comptime read_callback: *const AttachEventLoopCallbackSignature(T),
comptime write_callback: *const AttachEventLoopCallbackSignature(T), comptime write_callback: *const AttachEventLoopCallbackSignature(T),
comptime detach_callback: *const thunk.SimpleCallbackSignature(T), comptime detach_callback: *const thunkhelper.SimpleCallbackSignature(T),
loop: L, loop: L,
) Error!void { ) Error!void {
return Status.fromInt(nats_c.natsOptions_SetEventLoop( return Status.fromInt(nats_c.natsOptions_SetEventLoop(
@ptrCast(self), @ptrCast(self),
@constCast(@ptrCast(loop)), thunkhelper.opaqueFromUserdata(loop),
makeAttachEventLoopCallbackThunk(T, L, attach_callback), makeAttachEventLoopCallbackThunk(T, L, attach_callback),
makeEventLoopAddRemoveCallbackThunk(T, read_callback), makeEventLoopAddRemoveCallbackThunk(T, read_callback),
makeEventLoopAddRemoveCallbackThunk(T, write_callback), makeEventLoopAddRemoveCallbackThunk(T, write_callback),
@ -820,7 +819,7 @@ pub const ConnectionOptions = opaque {
@ptrCast(self), @ptrCast(self),
retry, retry,
makeConnectionCallbackThunk(T, callback), makeConnectionCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
@ -836,9 +835,9 @@ pub const ConnectionOptions = opaque {
return Status.fromInt(nats_c.natsOptions_SetUserCredentialsCallbacks( return Status.fromInt(nats_c.natsOptions_SetUserCredentialsCallbacks(
@ptrCast(self), @ptrCast(self),
makeJwtHandlerCallbackThunk(T, jwt_callback), makeJwtHandlerCallbackThunk(T, jwt_callback),
@constCast(@ptrCast(jwt_userdata)), thunkhelper.opaqueFromUserdata(jwt_userdata),
makeSignatureHandlerCallbackThunk(U, sig_callback), makeSignatureHandlerCallbackThunk(U, sig_callback),
@constCast(@ptrCast(sig_userdata)), thunkhelper.opaqueFromUserdata(sig_userdata),
)).raise(); )).raise();
} }
@ -876,7 +875,7 @@ pub const ConnectionOptions = opaque {
@ptrCast(self), @ptrCast(self),
pub_key.ptr, pub_key.ptr,
makeSignatureHandlerCallbackThunk(T, sig_callback), makeSignatureHandlerCallbackThunk(T, sig_callback),
@constCast(@ptrCast(sig_userdata)), thunkhelper.opaqueFromUserdata(sig_userdata),
)).raise(); )).raise();
} }
@ -917,18 +916,17 @@ pub const ConnectionOptions = opaque {
const TokenCallback = fn (?*anyopaque) callconv(.C) [*c]const u8; const TokenCallback = fn (?*anyopaque) callconv(.C) [*c]const u8;
pub fn TokenCallbackSignature(comptime T: type) type { pub fn TokenCallbackSignature(comptime UDT: type) type {
return fn (T) [:0]const u8; return fn (UDT) [:0]const u8;
} }
fn makeTokenCallbackThunk( fn makeTokenCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const TokenCallbackSignature(T), comptime callback: *const TokenCallbackSignature(UDT),
) *const TokenCallback { ) *const TokenCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk(userdata: ?*anyopaque) callconv(.C) [*c]const u8 { fn thunk(userdata: ?*anyopaque) callconv(.C) [*c]const u8 {
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const data = thunkhelper.userdataFromOpaque(UDT, userdata);
return callback(data).ptr; return callback(data).ptr;
} }
}.thunk; }.thunk;
@ -936,19 +934,18 @@ fn makeTokenCallbackThunk(
const ConnectionCallback = fn (?*nats_c.natsConnection, ?*anyopaque) callconv(.C) void; const ConnectionCallback = fn (?*nats_c.natsConnection, ?*anyopaque) callconv(.C) void;
pub fn ConnectionCallbackSignature(comptime T: type) type { pub fn ConnectionCallbackSignature(comptime UDT: type) type {
return fn (T, *Connection) void; return fn (UDT, *Connection) void;
} }
fn makeConnectionCallbackThunk( fn makeConnectionCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const ConnectionCallbackSignature(T), comptime callback: *const ConnectionCallbackSignature(UDT),
) *const ConnectionCallback { ) *const ConnectionCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk(conn: ?*nats_c.natsConnection, userdata: ?*anyopaque) callconv(.C) void { fn thunk(conn: ?*nats_c.natsConnection, userdata: ?*anyopaque) callconv(.C) void {
const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable;
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const data = thunkhelper.userdataFromOpaque(UDT, userdata);
callback(data, connection); callback(data, connection);
} }
}.thunk; }.thunk;
@ -956,15 +953,14 @@ fn makeConnectionCallbackThunk(
const ReconnectDelayCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) i64; const ReconnectDelayCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) i64;
pub fn ReconnectDelayCallbackSignature(comptime T: type) type { pub fn ReconnectDelayCallbackSignature(comptime UDT: type) type {
return fn (T, *Connection, c_int) i64; return fn (UDT, *Connection, c_int) i64;
} }
fn makeReconnectDelayCallbackThunk( fn makeReconnectDelayCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const ReconnectDelayCallbackSignature(T), comptime callback: *const ReconnectDelayCallbackSignature(UDT),
) *const ReconnectDelayCallback { ) *const ReconnectDelayCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
conn: ?*nats_c.natsConnection, conn: ?*nats_c.natsConnection,
@ -972,7 +968,7 @@ fn makeReconnectDelayCallbackThunk(
userdata: ?*anyopaque, userdata: ?*anyopaque,
) callconv(.C) i64 { ) callconv(.C) i64 {
const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable;
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const data = thunkhelper.userdataFromOpaque(UDT, userdata);
return callback(data, connection, attempts); return callback(data, connection, attempts);
} }
}.thunk; }.thunk;
@ -985,15 +981,14 @@ const ErrorHandlerCallback = fn (
?*anyopaque, ?*anyopaque,
) callconv(.C) void; ) callconv(.C) void;
pub fn ErrorHandlerCallbackSignature(comptime T: type) type { pub fn ErrorHandlerCallbackSignature(comptime UDT: type) type {
return fn (T, *Connection, *Subscription, Status) void; return fn (UDT, *Connection, *Subscription, Status) void;
} }
fn makeErrorHandlerCallbackThunk( fn makeErrorHandlerCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const ErrorHandlerCallbackSignature(T), comptime callback: *const ErrorHandlerCallbackSignature(UDT),
) *const ErrorHandlerCallback { ) *const ErrorHandlerCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
conn: ?*nats_c.natsConnection, conn: ?*nats_c.natsConnection,
@ -1003,8 +998,8 @@ fn makeErrorHandlerCallbackThunk(
) callconv(.C) void { ) callconv(.C) void {
const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable;
const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable; const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable;
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable;
const data = thunkhelper.userdataFromOpaque(UDT, userdata);
callback(data, connection, subscription, Status.fromInt(status)); callback(data, connection, subscription, Status.fromInt(status));
} }
}.thunk; }.thunk;
@ -1018,17 +1013,16 @@ const AttachEventLoopCallback = fn (
nats_c.natsSock, nats_c.natsSock,
) callconv(.C) nats_c.natsStatus; ) callconv(.C) nats_c.natsStatus;
pub fn AttachEventLoopCallbackSignature(comptime T: type, comptime L: type) type { pub fn AttachEventLoopCallbackSignature(comptime UDT: type, comptime L: type) type {
return fn (L, *Connection, c_int) anyerror!T; return fn (L, *Connection, c_int) anyerror!UDT;
} }
fn makeAttachEventLoopCallbackThunk( fn makeAttachEventLoopCallbackThunk(
comptime T: type, comptime UDT: type,
comptime L: type, comptime L: type,
comptime callback: *const AttachEventLoopCallbackSignature(T, L), comptime callback: *const AttachEventLoopCallbackSignature(UDT, L),
) *const ReconnectDelayCallback { ) *const ReconnectDelayCallback {
comptime checkUserDataType(T); comptime thunkhelper.checkUserDataType(L);
comptime checkUserDataType(L);
return struct { return struct {
fn thunk( fn thunk(
userdata: *?*anyopaque, userdata: *?*anyopaque,
@ -1037,10 +1031,12 @@ fn makeAttachEventLoopCallbackThunk(
sock: ?*nats_c.natsSock, sock: ?*nats_c.natsSock,
) callconv(.C) nats_c.natsStatus { ) callconv(.C) nats_c.natsStatus {
const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable;
const ev_loop: L = if (loop) |l| @alignCast(@ptrCast(l)) else unreachable;
userdata.* = callback(ev_loop, connection, sock) catch |err| const ev_loop = thunkhelper.userdataFromOpaque(L, loop);
const result = callback(ev_loop, connection, sock) catch |err|
return Status.fromError(err).toInt(); return Status.fromError(err).toInt();
userdata.* = thunkhelper.opaqueFromUserdata(result);
return nats_c.NATS_OK; return nats_c.NATS_OK;
} }
@ -1049,15 +1045,14 @@ fn makeAttachEventLoopCallbackThunk(
const EventLoopAddRemoveCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) nats_c.natsStatus; const EventLoopAddRemoveCallback = fn (?*nats_c.natsConnection, c_int, ?*anyopaque) callconv(.C) nats_c.natsStatus;
pub fn EventLoopAddRemoveCallbackSignature(comptime T: type) type { pub fn EventLoopAddRemoveCallbackSignature(comptime UDT: type) type {
return fn (T, *Connection, c_int) anyerror!void; return fn (UDT, *Connection, c_int) anyerror!void;
} }
fn makeEventLoopAddRemoveCallbackThunk( fn makeEventLoopAddRemoveCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const EventLoopAddRemoveCallbackSignature(T), comptime callback: *const EventLoopAddRemoveCallbackSignature(UDT),
) *const ReconnectDelayCallback { ) *const ReconnectDelayCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
conn: ?*nats_c.natsConnection, conn: ?*nats_c.natsConnection,
@ -1065,7 +1060,7 @@ fn makeEventLoopAddRemoveCallbackThunk(
userdata: ?*anyopaque, userdata: ?*anyopaque,
) callconv(.C) nats_c.natsStatus { ) callconv(.C) nats_c.natsStatus {
const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable;
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const data = thunkhelper.userdataFromOpaque(UDT, userdata);
callback(data, connection, attempts) catch |err| callback(data, connection, attempts) catch |err|
return Status.fromError(err).toInt(); return Status.fromError(err).toInt();
@ -1076,21 +1071,21 @@ fn makeEventLoopAddRemoveCallbackThunk(
const EventLoopDetachCallback = fn (?*anyopaque) callconv(.C) nats_c.natsStatus; const EventLoopDetachCallback = fn (?*anyopaque) callconv(.C) nats_c.natsStatus;
pub fn EventLoopDetachCallbackSignature(comptime T: type) type { pub fn EventLoopDetachCallbackSignature(comptime UDT: type) type {
return fn (T) anyerror!void; return fn (UDT) anyerror!void;
} }
fn makeEventLoopDetachCallbackThunk( fn makeEventLoopDetachCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const EventLoopDetachCallbackSignature(T), comptime callback: *const EventLoopDetachCallbackSignature(UDT),
) *const ReconnectDelayCallback { ) *const ReconnectDelayCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
userdata: ?*anyopaque, userdata: ?*anyopaque,
) callconv(.C) nats_c.natsStatus { ) callconv(.C) nats_c.natsStatus {
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const data = thunkhelper.userdataFromOpaque(UDT, userdata);
callback(data) catch |err| return Status.fromError(err).toInt(); callback(data) catch |err| return Status.fromError(err).toInt();
return nats_c.NATS_OK; return nats_c.NATS_OK;
} }
}.thunk; }.thunk;
@ -1105,26 +1100,24 @@ pub const JwtResponseOrError = union(enum) {
error_message: [:0]u8, error_message: [:0]u8,
}; };
pub fn JwtHandlerCallbackSignature(comptime T: type) type { pub fn JwtHandlerCallbackSignature(comptime UDT: type) type {
return fn (T) JwtResponseOrError; return fn (UDT) JwtResponseOrError;
} }
fn makeJwtHandlerCallbackThunk( fn makeJwtHandlerCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const JwtHandlerCallbackSignature(T), comptime callback: *const JwtHandlerCallbackSignature(UDT),
) *const JwtHandlerCallback { ) *const JwtHandlerCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
jwt_out_raw: ?*?[*:0]u8, jwt_out_raw: ?*?[*:0]u8,
err_out_raw: ?*?[*:0]u8, err_out_raw: ?*?[*:0]u8,
userdata: ?*anyopaque, userdata: ?*anyopaque,
) callconv(.C) nats_c.natsStatus { ) callconv(.C) nats_c.natsStatus {
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable;
const err_out = err_out_raw orelse unreachable; const err_out = err_out_raw orelse unreachable;
const jwt_out = jwt_out_raw orelse unreachable; const jwt_out = jwt_out_raw orelse unreachable;
switch (callback(data)) { switch (callback(thunkhelper.userdataFromOpaque(UDT, userdata))) {
.jwt => |jwt| { .jwt => |jwt| {
jwt_out.* = jwt.ptr; jwt_out.* = jwt.ptr;
return nats_c.NATS_OK; return nats_c.NATS_OK;
@ -1147,15 +1140,14 @@ pub const SignatureResponseOrError = union(enum) {
error_message: [:0]u8, error_message: [:0]u8,
}; };
pub fn SignatureHandlerCallbackSignature(comptime T: type) type { pub fn SignatureHandlerCallbackSignature(comptime UDT: type) type {
return fn (T, [:0]const u8) SignatureResponseOrError; return fn (UDT, [:0]const u8) SignatureResponseOrError;
} }
fn makeSignatureHandlerCallbackThunk( fn makeSignatureHandlerCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const SignatureHandlerCallbackSignature(T), comptime callback: *const SignatureHandlerCallbackSignature(UDT),
) *const SignatureHandlerCallback { ) *const SignatureHandlerCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
err_out_raw: ?*?[*:0]u8, err_out_raw: ?*?[*:0]u8,
@ -1164,12 +1156,12 @@ fn makeSignatureHandlerCallbackThunk(
nonsense: ?[*:0]const u8, nonsense: ?[*:0]const u8,
userdata: ?*anyopaque, userdata: ?*anyopaque,
) callconv(.C) nats_c.natsStatus { ) callconv(.C) nats_c.natsStatus {
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable;
const nonce = nonsense orelse unreachable; const nonce = nonsense orelse unreachable;
const err_out = err_out_raw orelse unreachable; const err_out = err_out_raw orelse unreachable;
const sig_out = sig_out_raw orelse unreachable; const sig_out = sig_out_raw orelse unreachable;
const sig_len_out = sig_len_out_raw orelse unreachable; const sig_len_out = sig_len_out_raw orelse unreachable;
const data = thunkhelper.userdataFromOpaque(UDT, userdata);
switch (callback(data, std.mem.sliceTo(nonce, 0))) { switch (callback(data, std.mem.sliceTo(nonce, 0))) {
.signature => |sig| { .signature => |sig| {
sig_out.* = sig.ptr; sig_out.* = sig.ptr;

View File

@ -24,8 +24,7 @@ const err_ = @import("./error.zig");
const Error = err_.Error; const Error = err_.Error;
const Status = err_.Status; const Status = err_.Status;
const thunk = @import("./thunk.zig"); const thunkhelper = @import("./thunk.zig");
const checkUserDataType = @import("./thunk.zig").checkUserDataType;
pub const Subscription = opaque { pub const Subscription = opaque {
pub const MessageCount = struct { pub const MessageCount = struct {
@ -171,13 +170,13 @@ pub const Subscription = opaque {
pub fn setCompletionCallback( pub fn setCompletionCallback(
self: *Subscription, self: *Subscription,
comptime T: type, comptime T: type,
comptime callback: *const thunk.SimpleCallbackThunkSignature(T), comptime callback: *const thunkhelper.SimpleCallbackThunkSignature(T),
userdata: T, userdata: T,
) Error!void { ) Error!void {
return Status.fromInt(nats_c.natsSubscription_SetOnCompleteCB( return Status.fromInt(nats_c.natsSubscription_SetOnCompleteCB(
@ptrCast(self), @ptrCast(self),
thunk.makeSimpleCallbackThunk(T, callback), thunkhelper.makeSimpleCallbackThunk(T, callback),
@constCast(@ptrCast(userdata)), thunkhelper.opaqueFromUserdata(userdata),
)).raise(); )).raise();
} }
}; };
@ -189,15 +188,14 @@ const SubscriptionCallback = fn (
?*anyopaque, ?*anyopaque,
) callconv(.C) void; ) callconv(.C) void;
pub fn SubscriptionCallbackSignature(comptime T: type) type { pub fn SubscriptionCallbackSignature(comptime UDT: type) type {
return fn (T, *Connection, *Subscription, *Message) void; return fn (UDT, *Connection, *Subscription, *Message) void;
} }
pub fn makeSubscriptionCallbackThunk( pub fn makeSubscriptionCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const SubscriptionCallbackSignature(T), comptime callback: *const SubscriptionCallbackSignature(UDT),
) *const SubscriptionCallback { ) *const SubscriptionCallback {
comptime checkUserDataType(T);
return struct { return struct {
fn thunk( fn thunk(
conn: ?*nats_c.natsConnection, conn: ?*nats_c.natsConnection,
@ -211,8 +209,7 @@ pub fn makeSubscriptionCallbackThunk(
const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable; const connection: *Connection = if (conn) |c| @ptrCast(c) else unreachable;
const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable; const subscription: *Subscription = if (sub) |s| @ptrCast(s) else unreachable;
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; const data = thunkhelper.userdataFromOpaque(UDT, userdata);
callback(data, connection, subscription, message); callback(data, connection, subscription, message);
} }
}.thunk; }.thunk;

View File

@ -16,68 +16,51 @@ const std = @import("std");
const nats_c = @import("./nats_c.zig").nats_c; const nats_c = @import("./nats_c.zig").nats_c;
pub fn CallbackType(comptime T: type) type { const optional = if (@hasField(std.builtin.Type, "optional")) .optional else .Optional;
return switch (@typeInfo(T)) { const pointer = if (@hasField(std.builtin.Type, "pointer")) .pointer else .Pointer;
.optional => |info| ?CallbackType(info.child), const void_type = if (@hasField(std.builtin.Type, "void")) .void else .Void;
.pointer => |info| switch (info.size) { const null_type = if (@hasField(std.builtin.Type, "null")) .null else .Null;
.Slice => *const T,
else => T, pub fn opaqueFromUserdata(userdata: anytype) ?*anyopaque {
}, checkUserDataType(@TypeOf(userdata));
else => *T, return switch (@typeInfo(@TypeOf(userdata))) {
optional, pointer => @constCast(@ptrCast(userdata)),
void_type => null,
else => @compileError("Unsupported userdata type " ++ @typeName(@TypeOf(userdata))),
}; };
} }
pub const checkUserDataType = if (@hasField(std.builtin.Type, "optional")) pub fn userdataFromOpaque(comptime UDT: type, userdata: ?*anyopaque) UDT {
checkUserDataType_14 comptime checkUserDataType(UDT);
else return if (UDT == void)
checkUserDataType_13; void{}
else if (@typeInfo(UDT) == optional)
pub fn checkUserDataType_14(comptime T: type) void { @alignCast(@ptrCast(userdata))
switch (@typeInfo(T)) { else
.optional => |info| switch (@typeInfo(info.child)) { @alignCast(@ptrCast(userdata.?));
.optional => @compileError(
"nats callbacks can only accept an (optional) single, many," ++
" or c pointer as userdata. \"" ++
@typeName(T) ++ "\" has more than one optional specifier.",
),
else => checkUserDataType(info.child),
},
.pointer => |info| switch (info.size) {
.Slice => @compileError(
"nats callbacks can only accept an (optional) single, many," ++
" or c pointer as userdata, not slices. \"" ++
@typeName(T) ++ "\" appears to be a slice.",
),
else => {},
},
else => @compileError(
"nats callbacks can only accept an (optional) single, many," ++
" or c pointer as userdata. \"" ++
@typeName(T) ++ "\" is not a pointer type.",
),
}
} }
pub fn checkUserDataType_13(comptime T: type) void { pub fn checkUserDataType(comptime T: type) void {
switch (@typeInfo(T)) { switch (@typeInfo(T)) {
.Optional => |info| switch (@typeInfo(info.child)) { optional => |info| switch (@typeInfo(info.child)) {
.Optional => @compileError( optional => @compileError(
"nats callbacks can only accept an (optional) single, many," ++ "nats callbacks can only accept void or an (optional) single, many," ++
" or c pointer as userdata. \"" ++ " or c pointer as userdata. \"" ++
@typeName(T) ++ "\" has more than one optional specifier.", @typeName(T) ++ "\" has more than one optional specifier.",
), ),
else => checkUserDataType(info.child), else => checkUserDataType(info.child),
}, },
.Pointer => |info| switch (info.size) { pointer => |info| switch (info.size) {
.Slice => @compileError( .Slice => @compileError(
"nats callbacks can only accept an (optional) single, many," ++ "nats callbacks can only accept void or an (optional) single, many," ++
" or c pointer as userdata, not slices. \"" ++ " or c pointer as userdata, not slices. \"" ++
@typeName(T) ++ "\" appears to be a slice.", @typeName(T) ++ "\" appears to be a slice.",
), ),
else => {}, else => {},
}, },
void_type => {},
else => @compileError( else => @compileError(
"nats callbacks can only accept an (optional) single, many," ++ "nats callbacks can only accept void or an (optional) single, many," ++
" or c pointer as userdata. \"" ++ " or c pointer as userdata. \"" ++
@typeName(T) ++ "\" is not a pointer type.", @typeName(T) ++ "\" is not a pointer type.",
), ),
@ -86,19 +69,18 @@ pub fn checkUserDataType_13(comptime T: type) void {
const SimpleCallback = fn (?*anyopaque) callconv(.C) void; const SimpleCallback = fn (?*anyopaque) callconv(.C) void;
pub fn SimpleCallbackThunkSignature(comptime T: type) type { pub fn SimpleCallbackThunkSignature(comptime UDT: type) type {
return fn (T) void; return fn (UDT) void;
} }
pub fn makeSimpleCallbackThunk( pub fn makeSimpleCallbackThunk(
comptime T: type, comptime UDT: type,
comptime callback: *const SimpleCallbackThunkSignature(T), comptime callback: *const SimpleCallbackThunkSignature(UDT),
) *const SimpleCallback { ) *const SimpleCallback {
comptime checkUserDataType(T); comptime checkUserDataType(UDT);
return struct { return struct {
fn thunk(userdata: ?*anyopaque) callconv(.C) void { fn thunk(userdata: ?*anyopaque) callconv(.C) void {
const data: T = if (userdata) |u| @alignCast(@ptrCast(u)) else unreachable; callback(userdataFromOpaque(UDT, userdata));
callback(data);
} }
}.thunk; }.thunk;
} }

View File

@ -99,42 +99,46 @@ test "nats.Connection" {
connection.drainTimeout(1000) catch {}; connection.drainTimeout(1000) catch {};
} }
fn reconnectDelayHandler(userdata: *const u32, connection: *nats.Connection, attempts: c_int) i64 { fn callbacks(comptime UDT: type) type {
_ = userdata; return struct {
_ = connection; fn reconnectDelayHandler(userdata: UDT, connection: *nats.Connection, attempts: c_int) i64 {
_ = attempts; _ = userdata;
_ = connection;
_ = attempts;
return 0; return 0;
} }
fn errorHandler( fn errorHandler(
userdata: *const u32, userdata: UDT,
connection: *nats.Connection, connection: *nats.Connection,
subscription: *nats.Subscription, subscription: *nats.Subscription,
status: nats.Status, status: nats.Status,
) void { ) void {
_ = userdata; _ = userdata;
_ = connection; _ = connection;
_ = subscription; _ = subscription;
_ = status; _ = status;
} }
fn connectionHandler(userdata: *const u32, connection: *nats.Connection) void { fn connectionHandler(userdata: UDT, connection: *nats.Connection) void {
_ = userdata; _ = userdata;
_ = connection; _ = connection;
} }
fn jwtHandler(userdata: *const u32) nats.JwtResponseOrError { fn jwtHandler(userdata: UDT) nats.JwtResponseOrError {
_ = userdata; _ = userdata;
// return .{ .jwt = std.heap.raw_c_allocator.dupeZ(u8, "abcdef") catch @panic("no!") }; // return .{ .jwt = std.heap.raw_c_allocator.dupeZ(u8, "abcdef") catch @panic("no!") };
return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "dang") catch @panic("no!") }; return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "dang") catch @panic("no!") };
} }
fn signatureHandler(userdata: *const u32, nonce: [:0]const u8) nats.SignatureResponseOrError { fn signatureHandler(userdata: UDT, nonce: [:0]const u8) nats.SignatureResponseOrError {
_ = userdata; _ = userdata;
_ = nonce; _ = nonce;
// return .{ .signature = std.heap.raw_c_allocator.dupe(u8, "01230123") catch @panic("no!") }; // return .{ .signature = std.heap.raw_c_allocator.dupe(u8, "01230123") catch @panic("no!") };
return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "whoops") catch @panic("no!") }; return .{ .error_message = std.heap.raw_c_allocator.dupeZ(u8, "whoops") catch @panic("no!") };
}
};
} }
test "nats.ConnectionOptions" { test "nats.ConnectionOptions" {
@ -164,14 +168,26 @@ test "nats.ConnectionOptions" {
try options.setMaxReconnect(10); try options.setMaxReconnect(10);
try options.setReconnectWait(500); try options.setReconnectWait(500);
try options.setReconnectJitter(100, 200); try options.setReconnectJitter(100, 200);
try options.setCustomReconnectDelay(*const u32, reconnectDelayHandler, &userdata); try options.setCustomReconnectDelay(*const u32, callbacks(*const u32).reconnectDelayHandler, &userdata);
try options.setCustomReconnectDelay(void, callbacks(void).reconnectDelayHandler, {});
try options.setCustomReconnectDelay(?*const u32, callbacks(?*const u32).reconnectDelayHandler, null);
try options.setReconnectBufSize(1024); try options.setReconnectBufSize(1024);
try options.setMaxPendingMessages(50); try options.setMaxPendingMessages(50);
try options.setErrorHandler(*const u32, errorHandler, &userdata); try options.setErrorHandler(*const u32, callbacks(*const u32).errorHandler, &userdata);
try options.setClosedCallback(*const u32, connectionHandler, &userdata); try options.setErrorHandler(void, callbacks(void).errorHandler, {});
try options.setDisconnectedCallback(*const u32, connectionHandler, &userdata); try options.setErrorHandler(?*const u32, callbacks(?*const u32).errorHandler, null);
try options.setDiscoveredServersCallback(*const u32, connectionHandler, &userdata); try options.setClosedCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata);
try options.setLameDuckModeCallback(*const u32, connectionHandler, &userdata); try options.setClosedCallback(void, callbacks(void).connectionHandler, {});
try options.setClosedCallback(?*const u32, callbacks(?*const u32).connectionHandler, null);
try options.setDisconnectedCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata);
try options.setDisconnectedCallback(void, callbacks(void).connectionHandler, {});
try options.setDisconnectedCallback(?*const u32, callbacks(?*const u32).connectionHandler, null);
try options.setDiscoveredServersCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata);
try options.setDiscoveredServersCallback(void, callbacks(void).connectionHandler, {});
try options.setDiscoveredServersCallback(?*const u32, callbacks(?*const u32).connectionHandler, null);
try options.setLameDuckModeCallback(*const u32, callbacks(*const u32).connectionHandler, &userdata);
try options.setLameDuckModeCallback(void, callbacks(void).connectionHandler, {});
try options.setLameDuckModeCallback(?*const u32, callbacks(?*const u32).connectionHandler, null);
try options.ignoreDiscoveredServers(true); try options.ignoreDiscoveredServers(true);
try options.useGlobalMessageDelivery(false); try options.useGlobalMessageDelivery(false);
try options.ipResolutionOrder(.ipv4_first); try options.ipResolutionOrder(.ipv4_first);
@ -179,8 +195,11 @@ test "nats.ConnectionOptions" {
try options.useOldRequestStyle(false); try options.useOldRequestStyle(false);
try options.setFailRequestsOnDisconnect(true); try options.setFailRequestsOnDisconnect(true);
try options.setNoEcho(true); try options.setNoEcho(true);
try options.setRetryOnFailedConnect(*const u32, connectionHandler, true, &userdata); try options.setRetryOnFailedConnect(*const u32, callbacks(*const u32).connectionHandler, true, &userdata);
try options.setUserCredentialsCallbacks(*const u32, *const u32, jwtHandler, signatureHandler, &userdata, &userdata); try options.setRetryOnFailedConnect(void, callbacks(void).connectionHandler, true, {});
try options.setRetryOnFailedConnect(?*const u32, callbacks(?*const u32).connectionHandler, true, null);
try options.setUserCredentialsCallbacks(*const u32, *const u32, callbacks(*const u32).jwtHandler, callbacks(*const u32).signatureHandler, &userdata, &userdata);
try options.setUserCredentialsCallbacks(void, void, callbacks(void).jwtHandler, callbacks(void).signatureHandler, {}, {});
try options.setWriteDeadline(5); try options.setWriteDeadline(5);
try options.disableNoResponders(true); try options.disableNoResponders(true);
try options.setCustomInboxPrefix("_FOOBOX"); try options.setCustomInboxPrefix("_FOOBOX");

View File

@ -1,9 +1,11 @@
// This file is licensed under the CC0 1.0 license. // This file is licensed under the CC0 1.0 license.
// See: https://creativecommons.org/publicdomain/zero/1.0/legalcode // See: https://creativecommons.org/publicdomain/zero/1.0/legalcode
test { comptime {
_ = @import("./nats.zig"); if (@import("builtin").is_test) {
_ = @import("./connection.zig"); _ = @import("./nats.zig");
_ = @import("./message.zig"); _ = @import("./connection.zig");
_ = @import("./subscription.zig"); _ = @import("./message.zig");
_ = @import("./subscription.zig");
}
} }

View File

@ -190,3 +190,109 @@ test "nats.Subscription (async)" {
); );
} }
} }
fn onVoidMessage(
userdata: void,
connection: *nats.Connection,
subscription: *nats.Subscription,
message: *nats.Message,
) void {
_ = subscription;
_ = userdata;
if (message.getReply()) |reply| {
connection.publish(reply, "greetings") catch @panic("OH NO");
} else @panic("HOW");
}
fn onVoidClose(userdata: void) void {
_ = userdata;
}
test "nats.Subscription (async, void)" {
var server = try util.TestServer.launch(.{});
defer server.stop();
try nats.init(nats.default_spin_count);
defer nats.deinit();
const connection = try nats.Connection.connectTo(server.url);
defer connection.destroy();
const message_subject: [:0]const u8 = "hello";
const message_reply: [:0]const u8 = "reply";
const message_data: [:0]const u8 = "world";
const message = try nats.Message.create(message_subject, message_reply, message_data);
defer message.destroy();
{
{
const subscription = try connection.subscribe(void, message_subject, onVoidMessage, {});
defer subscription.destroy();
try subscription.setCompletionCallback(void, onVoidClose, {});
const response = try connection.requestMessage(message, 1000);
try std.testing.expectEqualStrings(
"greetings",
response.getData() orelse return error.TestUnexpectedResult,
);
}
// we have to sleep to allow the close callback to run.
nats.sleep(1);
}
}
fn onNullMessage(
userdata: ?*void,
connection: *nats.Connection,
subscription: *nats.Subscription,
message: *nats.Message,
) void {
_ = subscription;
_ = userdata;
if (message.getReply()) |reply| {
connection.publish(reply, "greetings") catch @panic("OH NO");
} else @panic("HOW");
}
fn onNullClose(userdata: ?*void) void {
_ = userdata;
}
test "nats.Subscription (async, null)" {
var server = try util.TestServer.launch(.{});
defer server.stop();
try nats.init(nats.default_spin_count);
defer nats.deinit();
const connection = try nats.Connection.connectTo(server.url);
defer connection.destroy();
const message_subject: [:0]const u8 = "hello";
const message_reply: [:0]const u8 = "reply";
const message_data: [:0]const u8 = "world";
const message = try nats.Message.create(message_subject, message_reply, message_data);
defer message.destroy();
{
{
const subscription = try connection.subscribe(?*void, message_subject, onNullMessage, null);
defer subscription.destroy();
try subscription.setCompletionCallback(?*void, onNullClose, null);
const response = try connection.requestMessage(message, 1000);
try std.testing.expectEqualStrings(
"greetings",
response.getData() orelse return error.TestUnexpectedResult,
);
}
// we have to sleep to allow the close callback to run.
nats.sleep(1);
}
}