libressl-portable/tests/handshake_table.c
2023-09-03 18:24:16 -07:00

551 lines
12 KiB
C

/* $OpenBSD: handshake_table.c,v 1.18 2022/12/01 13:49:12 tb Exp $ */
/*
* Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
*
* Permission to use, copy, modify, and distribute this software for any
* purpose with or without fee is hereby granted, provided that the above
* copyright notice and this permission notice appear in all copies.
*
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
#include <err.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include "tls13_handshake.h"
#define MAX_FLAGS (UINT8_MAX + 1)
/*
 * From RFC 8446:
 *
 * Appendix A.  State Machine
 *
 *    This appendix provides a summary of the legal state transitions for
 *    the client and server handshakes.  State names (in all capitals,
 *    e.g., START) have no formal meaning but are provided for ease of
 *    comprehension.  Actions which are taken only in certain circumstances
 *    are indicated in [].  The notation "K_{send,recv} = foo" means "set
 *    the send/recv key to the given key".
 *
 * A.1.  Client
 *
 *                               START <----+
 *                Send ClientHello |        | Recv HelloRetryRequest
 *           [K_send = early data] |        |
 *                                 v        |
 *            /                 WAIT_SH ----+
 *            |                    | Recv ServerHello
 *            |                    | K_recv = handshake
 *        Can |                    V
 *       send |                 WAIT_EE
 *      early |                    | Recv EncryptedExtensions
 *       data |           +--------+--------+
 *            |     Using |                 | Using certificate
 *            |       PSK |                 v
 *            |           |            WAIT_CERT_CR
 *            |           |        Recv |       | Recv CertificateRequest
 *            |           | Certificate |       v
 *            |           |             |    WAIT_CERT
 *            |           |             |       | Recv Certificate
 *            |           |             v       v
 *            |           |              WAIT_CV
 *            |           |                 | Recv CertificateVerify
 *            |           +> WAIT_FINISHED <+
 *            |                  | Recv Finished
 *            \                  | [Send EndOfEarlyData]
 *                               | K_send = handshake
 *                               | [Send Certificate [+ CertificateVerify]]
 *     Can send                  | Send Finished
 *     app data   -->            | K_send = K_recv = application
 *     after here                v
 *                           CONNECTED
 *
 *    Note that with the transitions as shown above, clients may send
 *    alerts that derive from post-ServerHello messages in the clear or
 *    with the early data keys.  If clients need to send such alerts, they
 *    SHOULD first rekey to the handshake keys if possible.
 *
 */
/*
 * One outgoing edge of the handshake state machine described by
 * stateinfo[]: the message type that may follow the parent message,
 * together with the flag bits that control whether build_table()
 * follows the edge.
 */
struct child {
	/* Message type reached by this edge; an mt of 0 ends a stateinfo row. */
	enum tls13_message_type mt;
	/* Bit OR'd into the accumulated path flags once this node is recorded. */
	uint8_t flag;
	/*
	 * If nonzero, the edge is only followed when at least one of these
	 * bits is already set in the path flags.
	 */
	uint8_t forced;
	/*
	 * If nonzero, the edge is never followed when any of these bits is
	 * already set in the path flags.
	 */
	uint8_t illegal;
};
/*
 * Adjacency list of the TLS 1.3 handshake state machine, indexed by the
 * parent message type.  Each row lists the message types that may follow
 * the parent and is terminated by the first entry whose mt is 0.  Rows
 * that are not listed here are zero-initialized and so have no successors.
 *
 * The table is only ever read (by build_table() and generate_graphics()),
 * so it is const.
 */
static const struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
	[CLIENT_HELLO] = {
		{
			.mt = SERVER_HELLO_RETRY_REQUEST,
		},
		{
			/* Skipping the retry marks the path WITHOUT_HRR. */
			.mt = SERVER_HELLO,
			.flag = WITHOUT_HRR,
		},
	},
	[SERVER_HELLO_RETRY_REQUEST] = {
		{
			.mt = CLIENT_HELLO_RETRY,
		},
	},
	[CLIENT_HELLO_RETRY] = {
		{
			.mt = SERVER_HELLO,
		},
	},
	[SERVER_HELLO] = {
		{
			.mt = SERVER_ENCRYPTED_EXTENSIONS,
		},
	},
	[SERVER_ENCRYPTED_EXTENSIONS] = {
		{
			.mt = SERVER_CERTIFICATE_REQUEST,
		},
		{
			.mt = SERVER_CERTIFICATE,
			.flag = WITHOUT_CR,
		},
		{
			.mt = SERVER_FINISHED,
			.flag = WITH_PSK,
		},
	},
	[SERVER_CERTIFICATE_REQUEST] = {
		{
			.mt = SERVER_CERTIFICATE,
		},
	},
	[SERVER_CERTIFICATE] = {
		{
			.mt = SERVER_CERTIFICATE_VERIFY,
		},
	},
	[SERVER_CERTIFICATE_VERIFY] = {
		{
			.mt = SERVER_FINISHED,
		},
	},
	[SERVER_FINISHED] = {
		{
			/* Only taken if WITHOUT_CR or WITH_PSK is set. */
			.mt = CLIENT_FINISHED,
			.forced = WITHOUT_CR | WITH_PSK,
		},
		{
			/* Never taken if WITHOUT_CR or WITH_PSK is set. */
			.mt = CLIENT_CERTIFICATE,
			.illegal = WITHOUT_CR | WITH_PSK,
		},
	},
	[CLIENT_CERTIFICATE] = {
		{
			.mt = CLIENT_FINISHED,
		},
		{
			.mt = CLIENT_CERTIFICATE_VERIFY,
			.flag = WITH_CCV,
		},
	},
	[CLIENT_CERTIFICATE_VERIFY] = {
		{
			.mt = CLIENT_FINISHED,
		},
	},
	[CLIENT_FINISHED] = {
		{
			.mt = APPLICATION_DATA,
		},
	},
	[APPLICATION_DATA] = {
		{
			/* Final state: explicit empty row. */
			.mt = 0,
		},
	},
};
/* Number of rows in stateinfo[]. */
const size_t stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
/* Forward declarations, listed in the order the functions are defined. */
const char *flag2str(uint8_t flag);
const char *mt2str(enum tls13_message_type mt);
void fprint_flags(FILE *stream, uint8_t flags);
void fprint_entry(FILE *stream,
    enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags);
void edge(enum tls13_message_type start, enum tls13_message_type end,
    uint8_t flag);
void flag_label(uint8_t flag);
void forced_edges(enum tls13_message_type start, enum tls13_message_type end,
    uint8_t forced);
int generate_graphics(void);
size_t count_handshakes(void);
void build_table(enum tls13_message_type
    table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
    struct child current, struct child end, struct child path[],
    uint8_t flags, unsigned int depth);
int verify_table(enum tls13_message_type
    table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print);
void usage(void);
/*
 * Return the symbolic name of a single handshake flag bit.  0 is also
 * accepted (fprint_flags() passes it for an empty flag set).  Values with
 * no matching name yield "UNKNOWN"; values with more than one bit set
 * terminate the program via errx().
 */
const char *
flag2str(uint8_t flag)
{
	/* Clearing the lowest set bit leaves something iff >1 bit was set. */
	if ((flag & (flag - 1)) != 0)
		errx(1, "more than one bit is set");

	switch (flag) {
	case INITIAL:
		return "INITIAL";
	case NEGOTIATED:
		return "NEGOTIATED";
	case WITHOUT_CR:
		return "WITHOUT_CR";
	case WITHOUT_HRR:
		return "WITHOUT_HRR";
	case WITH_PSK:
		return "WITH_PSK";
	case WITH_CCV:
		return "WITH_CCV";
	case WITH_0RTT:
		return "WITH_0RTT";
	default:
		return "UNKNOWN";
	}
}
/*
 * Return the enumerator name of a TLS 1.3 message type as a string, or
 * "UNKNOWN" for values outside the enumeration.
 */
const char *
mt2str(enum tls13_message_type mt)
{
	switch (mt) {
	case INVALID:
		return "INVALID";
	case CLIENT_HELLO:
		return "CLIENT_HELLO";
	case CLIENT_HELLO_RETRY:
		return "CLIENT_HELLO_RETRY";
	case CLIENT_END_OF_EARLY_DATA:
		return "CLIENT_END_OF_EARLY_DATA";
	case CLIENT_CERTIFICATE:
		return "CLIENT_CERTIFICATE";
	case CLIENT_CERTIFICATE_VERIFY:
		return "CLIENT_CERTIFICATE_VERIFY";
	case CLIENT_FINISHED:
		return "CLIENT_FINISHED";
	case SERVER_HELLO:
		return "SERVER_HELLO";
	case SERVER_HELLO_RETRY_REQUEST:
		return "SERVER_HELLO_RETRY_REQUEST";
	case SERVER_ENCRYPTED_EXTENSIONS:
		return "SERVER_ENCRYPTED_EXTENSIONS";
	case SERVER_CERTIFICATE:
		return "SERVER_CERTIFICATE";
	case SERVER_CERTIFICATE_VERIFY:
		return "SERVER_CERTIFICATE_VERIFY";
	case SERVER_CERTIFICATE_REQUEST:
		return "SERVER_CERTIFICATE_REQUEST";
	case SERVER_FINISHED:
		return "SERVER_FINISHED";
	case APPLICATION_DATA:
		return "APPLICATION_DATA";
	case TLS13_NUM_MESSAGE_TYPES:
		return "TLS13_NUM_MESSAGE_TYPES";
	default:
		return "UNKNOWN";
	}
}
/*
 * Write flags to stream as a " | "-separated list of flag names, lowest
 * bit first.  An empty set is written as the name flag2str(0) returns.
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char *sep = "";
	unsigned int mask;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(flags));
		return;
	}

	/* mask walks 0x01, 0x02, ..., 0x80 and stops once it leaves 8 bits. */
	for (mask = 1; mask <= UINT8_MAX; mask <<= 1) {
		if ((flags & mask) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(flags & mask));
		sep = " | ";
	}
}
/*
 * Write one handshake as a C designated-initializer row:
 *	[FLAGS] = {
 *		MESSAGE_TYPE,
 *		...
 *	},
 * The path ends at its first 0 entry or after TLS13_NUM_MESSAGE_TYPES
 * entries, whichever comes first.
 */
void
fprint_entry(FILE *stream,
    enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
{
	int n;

	fprintf(stream, "\t[");
	fprint_flags(stream, flags);
	fprintf(stream, "] = {\n");

	for (n = 0; n < TLS13_NUM_MESSAGE_TYPES && path[n] != 0; n++)
		fprintf(stream, "\t\t%s,\n", mt2str(path[n]));

	fprintf(stream, "\t},\n");
}
/*
 * Print one graphviz edge "start -> end" to stdout, labelled with the
 * name of flag when flag is nonzero.
 */
void
edge(enum tls13_message_type start, enum tls13_message_type end,
    uint8_t flag)
{
	const char *from = mt2str(start);
	const char *to = mt2str(end);

	printf("\t%s -> %s", from, to);
	flag_label(flag);
	printf(";\n");
}
/* Print a graphviz label attribute naming flag; prints nothing for 0. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
/*
 * Print one graphviz edge from start to end for every bit set in forced,
 * each labelled with that single bit's name.  Prints nothing if forced
 * is 0.
 */
void
forced_edges(enum tls13_message_type start, enum tls13_message_type end,
    uint8_t forced)
{
	unsigned int mask;

	for (mask = 1; mask <= UINT8_MAX; mask <<= 1) {
		if ((forced & mask) != 0)
			edge(start, end, forced & mask);
	}
}
int
generate_graphics(void)
{
enum tls13_message_type start, end;
unsigned int child;
uint8_t flag;
uint8_t forced;
printf("digraph G {\n");
printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
for (child = 0; stateinfo[start][child].mt != 0; child++) {
end = stateinfo[start][child].mt;
flag = stateinfo[start][child].flag;
forced = stateinfo[start][child].forced;
if (forced == 0)
edge(start, end, flag);
else
forced_edges(start, end, forced);
}
}
printf("}\n");
return 0;
}
/* The library's handshake table and its row count, defined outside this file. */
extern enum tls13_message_type handshakes[][TLS13_NUM_MESSAGE_TYPES];
extern size_t handshake_count;

/*
 * Return how many rows of handshakes[] describe a handshake, i.e. have a
 * first message other than INVALID.
 */
size_t
count_handshakes(void)
{
	size_t idx, n = 0;

	for (idx = 0; idx < handshake_count; idx++) {
		if (handshakes[idx][0] == INVALID)
			continue;
		n++;
	}
	return n;
}
void
build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
struct child current, struct child end, struct child path[], uint8_t flags,
unsigned int depth)
{
unsigned int i;
if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
errx(1, "recursed too deeply");
/* Record current node. */
path[depth++] = current;
flags |= current.flag;
/* If we haven't reached the end, recurse over the children. */
if (current.mt != end.mt) {
for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
struct child child = stateinfo[current.mt][i];
int forced = stateinfo[current.mt][i].forced;
int illegal = stateinfo[current.mt][i].illegal;
if ((forced == 0 || (forced & flags)) &&
(illegal == 0 || !(illegal & flags)))
build_table(table, child, end, path, flags,
depth);
}
return;
}
if (flags == 0)
errx(1, "path does not set flags");
if (table[flags][0] != 0)
errx(1, "path traversed twice");
for (i = 0; i < depth; i++)
table[flags][i] = path[i].mt;
}
int
verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
int print)
{
int success = 1, i;
size_t num_valid, num_found = 0;
uint8_t flags = 0;
do {
if (table[flags][0] == 0)
continue;
num_found++;
for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
if (table[flags][i] != handshakes[flags][i]) {
fprintf(stderr,
"incorrect entry %d of handshake ", i);
fprint_flags(stderr, flags);
fprintf(stderr, "\n");
success = 0;
}
}
if (print)
fprint_entry(stdout, table[flags], flags);
} while(++flags != 0);
num_valid = count_handshakes();
if (num_valid != num_found) {
fprintf(stderr,
"incorrect number of handshakes: want %zu, got %zu.\n",
num_valid, num_found);
success = 0;
}
return success;
}
/* Print the command-line synopsis to stderr and exit with status 1. */
void
usage(void)
{
	fputs("usage: handshake_table [-C | -g]\n", stderr);
	exit(1);
}
/*
 * Build the handshake table from stateinfo[] and verify it against the
 * library's handshakes[] table.
 *
 *	-C	also print the generated table as C initializers
 *	-g	instead print the state machine as a graphviz digraph
 *
 * -C and -g are mutually exclusive, and no operands are accepted.
 * Exits 0 on success and 1 if the tables disagree.
 */
int
main(int argc, char *argv[])
{
	/*
	 * build_table() starts from NEGOTIATED, so the INITIAL row is
	 * seeded by hand.
	 */
	static enum tls13_message_type
	    hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = {
		[INITIAL] = {
			CLIENT_HELLO,
			SERVER_HELLO_RETRY_REQUEST,
			CLIENT_HELLO_RETRY,
			SERVER_HELLO,
		},
	};
	struct child path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
	struct child start = { .mt = CLIENT_HELLO };
	struct child end = { .mt = APPLICATION_DATA };
	int ch, graphviz = 0, print = 0;

	while ((ch = getopt(argc, argv, "Cg")) != -1) {
		switch (ch) {
		case 'C':
			print = 1;
			break;
		case 'g':
			graphviz = 1;
			break;
		default:
			usage();
		}
	}
	argc -= optind;
	argv += optind;

	/* usage() exits, so either condition ends the program here. */
	if (argc != 0 || (graphviz && print))
		usage();

	if (graphviz)
		return generate_graphics();

	build_table(hs_table, start, end, path, NEGOTIATED, 0);

	return verify_table(hs_table, print) ? 0 : 1;
}