diff --git a/libs/libc/netdb/lib_dns.h b/libs/libc/netdb/lib_dns.h index 8385ca563e3..a42d7909cb6 100644 --- a/libs/libc/netdb/lib_dns.h +++ b/libs/libc/netdb/lib_dns.h @@ -170,7 +170,7 @@ void dns_restorelock(unsigned int count); * ****************************************************************************/ -int dns_bind(sa_family_t family); +int dns_bind(sa_family_t family, bool stream); /**************************************************************************** * Name: dns_query diff --git a/libs/libc/netdb/lib_dnsbind.c b/libs/libc/netdb/lib_dnsbind.c index a61893e56ba..3b8017b6d5a 100644 --- a/libs/libc/netdb/lib_dnsbind.c +++ b/libs/libc/netdb/lib_dnsbind.c @@ -63,15 +63,16 @@ * ****************************************************************************/ -int dns_bind(sa_family_t family) +int dns_bind(sa_family_t family, bool stream) { + int stype = stream ? SOCK_STREAM : SOCK_DGRAM; struct timeval tv; int sd; int ret; /* Create a new socket */ - sd = socket(family, SOCK_DGRAM | SOCK_CLOEXEC, 0); + sd = socket(family, stype | SOCK_CLOEXEC, 0); if (sd < 0) { ret = -get_errno(); diff --git a/libs/libc/netdb/lib_dnsquery.c b/libs/libc/netdb/lib_dnsquery.c index 0564278075a..b636ef3bfc9 100644 --- a/libs/libc/netdb/lib_dnsquery.c +++ b/libs/libc/netdb/lib_dnsquery.c @@ -110,6 +110,189 @@ struct dns_query_data_s * Private Functions ****************************************************************************/ +/**************************************************************************** + * Name: stream_send + * + * Description: + * A wrapper of send() to deal with short results for SOCK_STREAM socket. + * + * Input Parameters: + * Same as send(). + * + * Returned Value: + * Same as send(). + * + ****************************************************************************/ + +static ssize_t stream_send(int fd, FAR const void *buf, size_t len) +{ + ssize_t total = 0; + + while (len > 0) + { + ssize_t ret = send(fd, buf, len, 0); + if (ret == -1) + { + if (total == 0) + { + total = ret; + } + break; + } + + buf = (FAR const uint8_t *)buf + len; + len -= ret; + total += ret; + } + + return total; +} + +/**************************************************************************** + * Name: stream_recv + * + * Description: + * A wrapper of recv() to deal with short results for SOCK_STREAM socket. + * + * Input Parameters: + * Same as recv(). + * + * Returned Value: + * Same as recv(). + * + ****************************************************************************/ + +static ssize_t stream_recv(int fd, FAR void *buf, size_t len) +{ + ssize_t total = 0; + + while (len > 0) + { + ssize_t ret = recv(fd, buf, len, 0); + if (ret == 0) + { + /* the peer closed the connection */ + + set_errno(EMSGSIZE); + ret = -1; + } + + if (ret == -1) + { + if (total == 0) + { + total = ret; + } + break; + } + + buf = (FAR uint8_t *)buf + len; + len -= ret; + total += ret; + } + + return total; +} + +/**************************************************************************** + * Name: stream_send_record + * + * Description: + * Send a DNS message over SOCK_STREAM socket. + * + * Input Parameters: + * Same as send(). + * + * Returned Value: + * Same as send(). + * + ****************************************************************************/ + +static ssize_t stream_send_record(int fd, FAR const void *buf, size_t len) +{ + ssize_t ret; + uint8_t reclen[2]; + + /* RFC 1035 + * 4.2.2. TCP usage + * + * > The message is prefixed with a two byte length field which + * > gives the message length, excluding the two byte length field. + */ + + reclen[0] = (uint8_t)(len >> 8); + reclen[1] = (uint8_t)len; + ret = stream_send(fd, reclen, sizeof(reclen)); + if (ret < sizeof(reclen)) + { + return -1; + } + + return stream_send(fd, buf, len); +} + +/**************************************************************************** + * Name: stream_recv_record + * + * Description: + * Receive a DNS message over SOCK_STREAM socket. + * + * Input Parameters: + * Same as recv(). + * + * Returned Value: + * Same as recv(). + * + ****************************************************************************/ + +static ssize_t stream_recv_record(int fd, FAR void *buf, size_t len) +{ + size_t rlen; + ssize_t ret; + uint8_t reclen[2]; + + /* RFC 1035 + * 4.2.2. TCP usage + * + * > The message is prefixed with a two byte length field which + * > gives the message length, excluding the two byte length field. + */ + + ret = stream_recv(fd, reclen, sizeof(reclen)); + if (ret < sizeof(reclen)) + { + if (ret >= 0) + { + set_errno(EMSGSIZE); + } + + return -1; + } + + rlen = ((uint16_t)reclen[0] << 8) + reclen[1]; + if (rlen > len) + { + nerr("ERROR: DNS response (%zu bytes) didn't fit " + "the buffer. (%zu bytes) You may need to bump " + "CONFIG_NETDB_DNSCLIENT_MAXRESPONSE\n", rlen, len); + set_errno(EMSGSIZE); + return -1; + } + + ret = stream_recv(fd, buf, rlen); + if (ret != rlen) + { + if (ret >= 0) + { + set_errno(EMSGSIZE); + } + + return -1; + } + + return ret; +} + /**************************************************************************** * Name: dns_parse_name * @@ -200,7 +383,8 @@ static inline uint16_t dns_alloc_id(void) static int dns_send_query(int sd, FAR const char *name, FAR union dns_addr_u *uaddr, uint16_t rectype, FAR struct dns_query_info_s *qinfo, - FAR uint8_t *buffer) + FAR uint8_t *buffer, + bool stream) { FAR struct dns_header_s *hdr; FAR uint8_t *dest; @@ -304,7 +488,15 @@ static int dns_send_query(int sd, FAR const char *name, return ret; } - ret = send(sd, buffer, dest - buffer, 0); + if (stream) + { + ret = stream_send_record(sd, buffer, dest - buffer); + } + else + { + ret = send(sd, buffer, dest - buffer, 0); + } + if (ret < 0) { ret = -get_errno(); @@ -329,7 +521,8 @@ static int dns_send_query(int sd, FAR const char *name, static int dns_recv_response(int sd, FAR union dns_addr_u *addr, int naddr, FAR struct dns_query_info_s *qinfo, - FAR uint32_t *ttl, FAR uint8_t *buffer) + FAR uint32_t *ttl, FAR uint8_t *buffer, + bool stream, bool *should_try_stream) { FAR uint8_t *nameptr; FAR uint8_t *namestart; @@ -350,7 +543,15 @@ static int dns_recv_response(int sd, FAR union dns_addr_u *addr, int naddr, /* Receive the response */ - ret = recv(sd, buffer, RECV_BUFFER_SIZE, 0); + if (stream) + { + ret = stream_recv_record(sd, buffer, RECV_BUFFER_SIZE); + } + else + { + ret = recv(sd, buffer, RECV_BUFFER_SIZE, 0); + } + if (ret < 0) { ret = -get_errno(); @@ -378,6 +579,29 @@ static int dns_recv_response(int sd, FAR union dns_addr_u *addr, int naddr, /* Check for error */ + if ((hdr->flags1 & DNS_FLAG1_TRUNC) != 0) + { + /* RFC 2181 + * 9. The TC (truncated) header bit + * + * > When a DNS client receives a reply with TC set, + * > it should ignore that response, and query again, + * > using a mechanism, such as a TCP connection, + * > that will permit larger replies. + */ + + if (stream) + { + nerr("ERROR: DNS response truncated on stream socket.\n"); + return -EPROTO; + } + + ninfo("ERROR: DNS response truncated. " + "Falling back to stream socket.\n"); + *should_try_stream = true; + return -EAGAIN; + } + if ((hdr->flags2 & DNS_FLAG2_ERR_MASK) != 0) { nerr("ERROR: DNS reported error: flags2=%02x\n", hdr->flags2); @@ -627,6 +851,7 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, int retries; int ret; int sd; + bool stream = false; /* Loop while receive timeout errors occur and there are remaining * retries. @@ -634,12 +859,15 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, for (retries = 0; retries < CONFIG_NETDB_DNSCLIENT_RETRIES; retries++) { + bool should_try_stream; + +try_stream: #ifdef CONFIG_NET_IPv6 if (dns_is_queryfamily(AF_INET6)) { /* Send the IPv6 query */ - sd = dns_bind(addr->sa_family); + sd = dns_bind(addr->sa_family, stream); if (sd < 0) { query->result = sd; @@ -649,7 +877,7 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, ret = dns_send_query(sd, query->hostname, (FAR union dns_addr_u *)addr, DNS_RECTYPE_AAAA, &qdata->qinfo, - qdata->buffer); + qdata->buffer, stream); if (ret < 0) { dns_query_error("ERROR: IPv6 dns_send_query failed", @@ -660,16 +888,24 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, { /* Obtain the IPv6 response */ + should_try_stream = false; ret = dns_recv_response(sd, &query->addr[next], CONFIG_NETDB_MAX_IPv6ADDR, &qdata->qinfo, - &query->ttl, qdata->buffer); + &query->ttl, qdata->buffer, + stream, &should_try_stream); if (ret >= 0) { next += ret; } else { + if (!stream && should_try_stream) + { + stream = true; + goto try_stream; /* Don't consume retry count */ + } + dns_query_error("ERROR: IPv6 dns_recv_response failed", ret, (FAR union dns_addr_u *)addr); query->result = ret; @@ -685,7 +921,7 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, { /* Send the IPv4 query */ - sd = dns_bind(addr->sa_family); + sd = dns_bind(addr->sa_family, stream); if (sd < 0) { query->result = sd; @@ -694,7 +930,8 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, ret = dns_send_query(sd, query->hostname, (FAR union dns_addr_u *)addr, - DNS_RECTYPE_A, &qdata->qinfo, qdata->buffer); + DNS_RECTYPE_A, &qdata->qinfo, qdata->buffer, + stream); if (ret < 0) { dns_query_error("ERROR: IPv4 dns_send_query failed", @@ -710,16 +947,24 @@ static int dns_query_callback(FAR void *arg, FAR struct sockaddr *addr, next = *query->naddr / 2; } + should_try_stream = false; ret = dns_recv_response(sd, &query->addr[next], CONFIG_NETDB_MAX_IPv4ADDR, &qdata->qinfo, - &query->ttl, qdata->buffer); + &query->ttl, qdata->buffer, + stream, &should_try_stream); if (ret >= 0) { next += ret; } else { + if (!stream && should_try_stream) + { + stream = true; + goto try_stream; /* Don't consume retry count */ + } + dns_query_error("ERROR: IPv4 dns_recv_response failed", ret, (FAR union dns_addr_u *)addr); query->result = ret;