From e0006c078782a3463fb354358a18a99b04b616da Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Sun, 14 Apr 2024 11:22:03 +0200 Subject: [PATCH 01/28] WIP: Working Box> based Service but with stream & middleware disabled. --- examples/serve-zone.rs | 93 ++- src/net/server/connection.rs | 196 +++--- src/net/server/dgram.rs | 380 ++++++----- src/net/server/message.rs | 629 +++++++++--------- src/net/server/middleware/chain.rs | 22 +- .../middleware/processors/mandatory_svc.rs | 527 +++++++++++++++ src/net/server/middleware/processors/mod.rs | 1 + src/net/server/mod.rs | 8 +- src/net/server/service.rs | 34 +- src/net/server/stream.rs | 6 +- src/net/server/tests.rs | 198 +++--- src/net/server/util.rs | 16 +- 12 files changed, 1357 insertions(+), 753 deletions(-) create mode 100644 src/net/server/middleware/processors/mandatory_svc.rs diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index dbe082d02..0390028e2 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -21,17 +21,17 @@ use domain::base::{Dname, Message, Rtype, ToDname}; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; -use domain::net::server::service::{ - CallResult, ServiceError, Transaction, TransactionStream, -}; -use domain::net::server::stream::StreamServer; +use domain::net::server::service::{CallResult, ServiceError}; +// use domain::net::server::stream::StreamServer; use domain::net::server::util::{mk_builder_for_target, service_fn}; use domain::zonefile::inplace; use domain::zonetree::{Answer, Rrset}; use domain::zonetree::{Zone, ZoneTree}; +use futures::stream::{once, FuturesOrdered}; use octseq::OctetsBuilder; use std::future::{pending, ready, Future}; use std::io::BufReader; +use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::time::Duration; use tokio::net::{TcpListener, UdpSocket}; @@ -76,11 +76,11 @@ async fn main() { tokio::spawn(async move { udp_srv.run().await }); } - let sock = TcpListener::bind(addr).await.unwrap(); - let tcp_srv = StreamServer::new(sock, VecBufSource, svc); - let tcp_metrics = tcp_srv.metrics(); + // let sock = TcpListener::bind(addr).await.unwrap(); + // let tcp_srv = StreamServer::new(sock, VecBufSource, svc); + // let tcp_metrics = tcp_srv.metrics(); - tokio::spawn(async move { tcp_srv.run().await }); + // tokio::spawn(async move { tcp_srv.run().await }); tokio::spawn(async move { loop { @@ -95,14 +95,14 @@ async fn main() { metrics.num_sent_responses(), ); } - eprintln!( - "Server status: TCP: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", - tcp_metrics.num_connections(), - tcp_metrics.num_inflight_requests(), - tcp_metrics.num_pending_writes(), - tcp_metrics.num_received_requests(), - tcp_metrics.num_sent_responses(), - ); + // eprintln!( + // "Server status: TCP: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", + // tcp_metrics.num_connections(), + // tcp_metrics.num_inflight_requests(), + // tcp_metrics.num_pending_writes(), + // tcp_metrics.num_received_requests(), + // tcp_metrics.num_sent_responses(), + // ); } }); @@ -113,22 +113,22 @@ async fn main() { fn my_service( request: Request>, zones: Arc, -) -> Result< - Transaction< - Vec, - impl Future>, ServiceError>> + Send, - >, - ServiceError, -> { +) -> + Box< + dyn futures::stream::Stream< + Item = Result>, ServiceError>, + > + Send + + Unpin, + > + { let qtype = request.message().sole_question().unwrap().qtype(); match qtype { Rtype::AXFR if request.transport_ctx().is_non_udp() => { - let fut = handle_axfr_request(request, zones); - Ok(Transaction::stream(Box::pin(fut))) + Box::new(handle_axfr_request(request, zones)) } _ => { - let fut = handle_non_axfr_request(request, zones); - Ok(Transaction::single(fut)) + let fut = Box::pin(handle_non_axfr_request(request, zones)); + Box::new(once(fut)) } } } @@ -155,11 +155,22 @@ async fn handle_non_axfr_request( Ok(CallResult::new(additional)) } -async fn handle_axfr_request( +fn handle_axfr_request( request: Request>, zones: Arc, -) -> TransactionStream>, ServiceError>> { - let mut stream = TransactionStream::default(); +) -> FuturesOrdered< + Pin>, ServiceError>> + Send>>, +> { + // let mut stream = TransactionStream::default(); + let mut stream = FuturesOrdered::< + Pin< + Box< + dyn Future< + Output = Result>, ServiceError>, + > + Send, + >, + >, + >::new(); // Look up the zone for the queried name. let question = request.message().sole_question().unwrap(); @@ -261,7 +272,15 @@ async fn handle_axfr_request( fn add_to_stream( answer: Answer, msg: &Message>, - stream: &mut TransactionStream>, ServiceError>>, + stream: &mut FuturesOrdered< + Pin< + Box< + dyn Future< + Output = Result>, ServiceError>, + > + Send, + >, + >, + >, ) { let builder = mk_builder_for_target(); let additional = answer.to_message(msg, builder); @@ -272,10 +291,18 @@ fn add_to_stream( fn add_additional_to_stream( mut additional: AdditionalBuilder>>, msg: &Message>, - stream: &mut TransactionStream>, ServiceError>>, + stream: &mut FuturesOrdered< + Pin< + Box< + dyn Future< + Output = Result>, ServiceError>, + > + Send, + >, + >, + >, ) { set_axfr_header(msg, &mut additional); - stream.push(ready(Ok(CallResult::new(additional)))); + stream.push_back(Box::pin(ready(Ok(CallResult::new(additional))))); } fn set_axfr_header( diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index 173079fc1..1c3177f8e 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -21,10 +21,10 @@ use tracing::{debug, enabled, error, trace, warn}; use crate::base::wire::Composer; use crate::base::{Message, StreamTarget}; use crate::net::server::buf::BufSource; -use crate::net::server::message::CommonMessageFlow; +// use crate::net::server::message::CommonMessageFlow; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -use crate::net::server::middleware::chain::MiddlewareChain; +// use crate::net::server::middleware::chain::MiddlewareChain; use crate::net::server::service::{ CallResult, Service, ServiceError, ServiceFeedback, }; @@ -32,7 +32,7 @@ use crate::net::server::util::to_pcap_text; use crate::utils::config::DefMinMax; use super::message::{NonUdpTransportContext, TransportSpecificContext}; -use super::middleware::builder::MiddlewareBuilder; +// use super::middleware::builder::MiddlewareBuilder; use super::stream::Config as ServerConfig; use super::ServerCommand; use std::fmt::Display; @@ -118,9 +118,9 @@ pub struct Config { /// Limit on the number of DNS responses queued for wriing to the client. max_queued_responses: usize, - /// The middleware chain used to pre-process requests and post-process - /// responses. - middleware_chain: MiddlewareChain, + // /// The middleware chain used to pre-process requests and post-process + // /// responses. + // middleware_chain: MiddlewareChain, } impl Config @@ -209,24 +209,24 @@ where self.max_queued_responses = value; } - /// Set the middleware chain used to pre-process requests and post-process - /// responses. - /// - /// # Reconfigure - /// - /// On [`StreamServer::reconfigure`] only new connections created after - /// this setting is changed will use the new value, existing connections - /// and in-flight requests (and their responses) will continue to use - /// their current middleware chain. - /// - /// [`StreamServer::reconfigure`]: - /// super::stream::StreamServer::reconfigure() - pub fn set_middleware_chain( - &mut self, - value: MiddlewareChain, - ) { - self.middleware_chain = value; - } + // /// Set the middleware chain used to pre-process requests and post-process + // /// responses. + // /// + // /// # Reconfigure + // /// + // /// On [`StreamServer::reconfigure`] only new connections created after + // /// this setting is changed will use the new value, existing connections + // /// and in-flight requests (and their responses) will continue to use + // /// their current middleware chain. + // /// + // /// [`StreamServer::reconfigure`]: + // /// super::stream::StreamServer::reconfigure() + // pub fn set_middleware_chain( + // &mut self, + // value: MiddlewareChain, + // ) { + // self.middleware_chain = value; + // } } //--- Default @@ -241,7 +241,7 @@ where idle_timeout: IDLE_TIMEOUT.default(), response_write_timeout: RESPONSE_WRITE_TIMEOUT.default(), max_queued_responses: MAX_QUEUED_RESPONSES.default(), - middleware_chain: MiddlewareBuilder::default().build(), + // middleware_chain: MiddlewareBuilder::default().build(), } } } @@ -254,7 +254,7 @@ impl Clone for Config { idle_timeout: self.idle_timeout, response_write_timeout: self.response_write_timeout, max_queued_responses: self.max_queued_responses, - middleware_chain: self.middleware_chain.clone(), + // middleware_chain: self.middleware_chain.clone(), } } } @@ -407,7 +407,7 @@ where ServerCommand>, >, ) where - Svc::Future: Send, + Svc::Stream: Send, { self.metrics.inc_num_connections(); @@ -426,7 +426,7 @@ where Buf: BufSource + Send + Sync + Clone + 'static, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static, - Svc::Future: Send, + Svc::Stream: Send, Svc::Target: Send + Composer + Default, { /// Connection handler main loop. @@ -552,7 +552,7 @@ where idle_timeout, response_write_timeout, max_queued_responses: _, - middleware_chain: _, + // middleware_chain: _, }, .. // Ignore the Server specific configuration settings }) => { @@ -715,7 +715,7 @@ where res: Result, ) -> Result<(), ConnectionEvent> where - Svc::Future: Send, + Svc::Stream: Send, { res.and_then(|msg| { let received_at = Instant::now(); @@ -731,16 +731,18 @@ where self.idle_timer.full_msg_received(); // Process the received message - self.process_request( - msg, - received_at, - self.addr, - self.config.middleware_chain.clone(), - &self.service, - self.metrics.clone(), - self.result_q_tx.clone(), - ) - .map_err(ConnectionEvent::ServiceError) + // self.process_request( + // msg, + // received_at, + // self.addr, + // self.config.middleware_chain.clone(), + // &self.service, + // self.metrics.clone(), + // self.result_q_tx.clone(), + // ) + // .map_err(ConnectionEvent::ServiceError) + + todo!() }) } } @@ -760,63 +762,63 @@ where } } -//--- CommonMessageFlow - -impl CommonMessageFlow - for Connection -where - Buf: BufSource, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, - Svc::Target: Send, -{ - type Meta = Sender>; - - /// Add information to the request that relates to the type of server we - /// are and our state where relevant. - fn add_context_to_request( - &self, - request: Message, - received_at: Instant, - addr: SocketAddr, - ) -> Request { - let ctx = NonUdpTransportContext::new(Some(self.config.idle_timeout)); - let ctx = TransportSpecificContext::NonUdp(ctx); - Request::new(addr, received_at, request, ctx) - } - - /// Process the result from the middleware -> service -> middleware call - /// tree. - fn process_call_result( - _request: &Request, - call_result: CallResult, - tx: Self::Meta, - metrics: Arc, - ) { - // We can't send in a spawned async task as then we would just - // accumlate tasks even if the target queue is full. We can't call - // `tx.blocking_send()` as that would block the Tokio runtime. So - // instead we try and send and if that fails because the queue is full - // then we abort. - match tx.try_send(call_result) { - Ok(()) => { - metrics.set_num_pending_writes( - tx.max_capacity() - tx.capacity(), - ); - } - - Err(TrySendError::Closed(_msg)) => { - // TODO: How should we properly communicate this to the operator? - error!("Unable to queue message for sending: server is shutting down."); - } - - Err(TrySendError::Full(_msg)) => { - // TODO: How should we properly communicate this to the operator? - error!("Unable to queue message for sending: queue is full."); - } - } - } -} +// //--- CommonMessageFlow + +// impl CommonMessageFlow +// for Connection +// where +// Buf: BufSource, +// Buf::Output: Octets + Send + Sync + 'static, +// Svc: Service + Send + Sync + 'static, +// Svc::Target: Send, +// { +// type Meta = Sender>; + +// /// Add information to the request that relates to the type of server we +// /// are and our state where relevant. +// fn add_context_to_request( +// &self, +// request: Message, +// received_at: Instant, +// addr: SocketAddr, +// ) -> Request { +// let ctx = NonUdpTransportContext::new(Some(self.config.idle_timeout)); +// let ctx = TransportSpecificContext::NonUdp(ctx); +// Request::new(addr, received_at, request, ctx) +// } + +// /// Process the result from the middleware -> service -> middleware call +// /// tree. +// fn process_call_result( +// _request: &Request, +// call_result: CallResult, +// tx: Self::Meta, +// metrics: Arc, +// ) { +// // We can't send in a spawned async task as then we would just +// // accumlate tasks even if the target queue is full. We can't call +// // `tx.blocking_send()` as that would block the Tokio runtime. So +// // instead we try and send and if that fails because the queue is full +// // then we abort. +// match tx.try_send(call_result) { +// Ok(()) => { +// metrics.set_num_pending_writes( +// tx.max_capacity() - tx.capacity(), +// ); +// } + +// Err(TrySendError::Closed(_msg)) => { +// // TODO: How should we properly communicate this to the operator? +// error!("Unable to queue message for sending: server is shutting down."); +// } + +// Err(TrySendError::Full(_msg)) => { +// // TODO: How should we properly communicate this to the operator? +// error!("Unable to queue message for sending: queue is full."); +// } +// } +// } +// } //----------- DnsMessageReceiver --------------------------------------------- diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index 59356c36d..f9188da99 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -20,6 +20,7 @@ use std::string::String; use std::string::ToString; use std::sync::{Arc, Mutex}; +use futures::StreamExt; use octseq::Octets; use tokio::io::ReadBuf; use tokio::net::UdpSocket; @@ -34,10 +35,10 @@ use tracing::{enabled, error, trace}; use crate::base::Message; use crate::net::server::buf::BufSource; use crate::net::server::error::Error; -use crate::net::server::message::CommonMessageFlow; +// use crate::net::server::message::CommonMessageFlow; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -use crate::net::server::middleware::chain::MiddlewareChain; +// use crate::net::server::middleware::chain::MiddlewareChain; use crate::net::server::service::{CallResult, Service, ServiceFeedback}; use crate::net::server::sock::AsyncDgramSock; use crate::net::server::util::to_pcap_text; @@ -45,7 +46,7 @@ use crate::utils::config::DefMinMax; use super::buf::VecBufSource; use super::message::{TransportSpecificContext, UdpTransportContext}; -use super::middleware::builder::MiddlewareBuilder; +// use super::middleware::builder::MiddlewareBuilder; use super::ServerCommand; use crate::base::wire::Composer; use arc_swap::ArcSwap; @@ -84,23 +85,22 @@ const MAX_RESPONSE_SIZE: DefMinMax = DefMinMax::new(1232, 512, 4096); /// Configuration for a datagram server. #[derive(Debug)] -pub struct Config { +pub struct Config /**/ { /// Limit suggested to [`Service`] on maximum response size to create. max_response_size: Option, /// Limit the time to wait for a complete message to be written to the client. write_timeout: Duration, - - /// The middleware chain used to pre-process requests and post-process - /// responses. - middleware_chain: MiddlewareChain, + // /// The middleware chain used to pre-process requests and post-process + // /// responses. + // middleware_chain: MiddlewareChain, } -impl Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +// impl Config +// where +// RequestOctets: Octets, +// Target: Composer + Default, +impl Config { /// Creates a new, default config. pub fn new() -> Self { Default::default() @@ -145,50 +145,50 @@ where self.write_timeout = value; } - /// Set the middleware chain used to pre-process requests and post-process - /// responses. - /// - /// # Reconfigure - /// - /// On [`DgramServer::reconfigure`]` any change to this setting will only - /// affect requests (and their responses) received after the setting is - /// changed, in progress requests will be unaffected. - pub fn set_middleware_chain( - &mut self, - value: MiddlewareChain, - ) { - self.middleware_chain = value; - } + // /// Set the middleware chain used to pre-process requests and post-process + // /// responses. + // /// + // /// # Reconfigure + // /// + // /// On [`DgramServer::reconfigure`]` any change to this setting will only + // /// affect requests (and their responses) received after the setting is + // /// changed, in progress requests will be unaffected. + // pub fn set_middleware_chain( + // &mut self, + // value: MiddlewareChain, + // ) { + // self.middleware_chain = value; + // } } //--- Default -impl Default for Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +// impl Default for Config +// where +// RequestOctets: Octets, +// Target: Composer + Default, +impl Default for Config { fn default() -> Self { Self { max_response_size: Some(MAX_RESPONSE_SIZE.default()), write_timeout: WRITE_TIMEOUT.default(), - middleware_chain: MiddlewareBuilder::default().build(), + // middleware_chain: MiddlewareBuilder::default().build(), } } } //--- Clone -impl Clone for Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +// impl Clone for Config +// where +// RequestOctets: Octets, +// Target: Composer + Default, +impl Clone for Config { fn clone(&self) -> Self { Self { max_response_size: self.max_response_size, write_timeout: self.write_timeout, - middleware_chain: self.middleware_chain.clone(), + // middleware_chain: self.middleware_chain.clone(), } } } @@ -196,14 +196,17 @@ where //------------ DgramServer --------------------------------------------------- /// A [`ServerCommand`] capable of propagating a DgramServer [`Config`] value. -type ServerCommandType = ServerCommand>; +// type ServerCommandType = ServerCommand>; +type ServerCommandType = ServerCommand; /// A thread safe sender of [`ServerCommand`]s. -type CommandSender = - Arc>>>; +// type CommandSender = +// Arc>>>; +type CommandSender = Arc>>; /// A thread safe receiver of [`ServerCommand`]s. -type CommandReceiver = watch::Receiver>; +// type CommandReceiver = watch::Receiver>; +type CommandReceiver = watch::Receiver; /// A server for connecting clients via a datagram based network transport to /// a [`Service`]. @@ -297,22 +300,22 @@ where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + 'static, + Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, { /// The configuration of the server. - config: Arc>>, + config: Arc*/>>, /// A receiver for receiving [`ServerCommand`]s. /// /// Used by both the server and spawned connections to react to sent /// commands. - command_rx: CommandReceiver, + command_rx: CommandReceiver, //, /// A sender for sending [`ServerCommand`]s. /// /// Used to signal the server to stop, reconfigure, etc. - command_tx: CommandSender, + command_tx: CommandSender, //, /// The network socket over which client requests will be received /// and responses sent. @@ -335,7 +338,7 @@ where Sock: AsyncDgramSock + Send + Sync, Buf: BufSource + Send + Sync, Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync, + Svc: Service + Send + Sync + Clone, Svc::Target: Send + Composer + Default, { /// Constructs a new [`DgramServer`] with default configuration. @@ -364,7 +367,7 @@ where sock: Sock, buf: Buf, service: Svc, - config: Config, + config: Config, //, ) -> Self { let (command_tx, command_rx) = watch::channel(ServerCommand::Init); let command_tx = Arc::new(Mutex::new(command_tx)); @@ -390,7 +393,7 @@ where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, Buf::Output: Octets + Send + Sync + 'static + Debug, - Svc: Service + Send + Sync + 'static, + Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Debug + Default, { /// Get a reference to the network source being used to receive messages. @@ -413,7 +416,7 @@ where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync, Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, + Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, { /// Start the server. @@ -424,8 +427,8 @@ where /// /// [`shutdown`]: Self::shutdown pub async fn run(&self) - where - Svc::Future: Send, + // where + // Svc::Stream: Send, { if let Err(err) = self.run_until_error().await { error!("Server stopped due to error: {err}"); @@ -437,7 +440,7 @@ where /// pub fn reconfigure( &self, - config: Config, + config: Config, //,, ) -> Result<(), Error> { self.command_tx .lock() @@ -501,13 +504,13 @@ where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync, Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, + Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, { /// Receive incoming messages until shutdown or fatal error. async fn run_until_error(&self) -> Result<(), String> - where - Svc::Future: Send, +// where + // Svc::Stream: Send, { let mut command_rx = self.command_rx.clone(); @@ -523,7 +526,7 @@ where } _ = self.sock.readable() => { - let (msg, addr, bytes_read) = match self.recv_from() { + let (buf, addr, bytes_read) = match self.recv_from() { Ok(res) => res, Err(err) if err.kind() == io::ErrorKind::WouldBlock => continue, Err(err) => return Err(format!("Error while receiving message: {err}")), @@ -533,22 +536,80 @@ where self.metrics.inc_num_received_requests(); if enabled!(Level::TRACE) { - let pcap_text = to_pcap_text(&msg, bytes_read); + let pcap_text = to_pcap_text(&buf, bytes_read); trace!(%addr, pcap_text, "Received message"); } let state = self.mk_state_for_request(); - self.process_request( - msg, received_at, addr, - self.config.load().middleware_chain.clone(), - &self.service, - self.metrics.clone(), - state, - ) - .map_err(|err| - format!("Error while processing message: {err}") - )?; + // self.process_request( + // msg, received_at, addr, + // self.config.load().middleware_chain.clone(), + // &self.service, + // self.metrics.clone(), + // state, + // ) + // .map_err(|err| + // format!("Error while processing message: {err}") + // )?; + + let svc = self.service.clone(); + let cfg = self.config.clone(); + let metrics = self.metrics.clone(); + + tokio::spawn(async move { + match Message::from_octets(buf) { + Err(err) => { + tracing::warn!("Failed while parsing request message: {err}"); + } + + Ok(msg) => { + let ctx = UdpTransportContext::new(cfg.load().max_response_size); + let ctx = TransportSpecificContext::Udp(ctx); + let request = Request::new(addr, received_at, msg, ctx); + let mut stream = svc.call(request); + while let Some(Ok(call_result)) = stream.next().await { + let (response, feedback) = call_result.into_inner(); + + if let Some(feedback) = feedback { + match feedback { + ServiceFeedback::Reconfigure { + idle_timeout: _, // N/A - only applies to connection-oriented transports + } => { + // Nothing to do. + } + } + } + + // Process the DNS response message, if any. + if let Some(response) = response { + // Convert the DNS response message into bytes. + let target = response.finish(); + let bytes = target.as_dgram_slice(); + + // Logging + if enabled!(Level::TRACE) { + let pcap_text = to_pcap_text(bytes, bytes.len()); + trace!(%addr, pcap_text, "Sending response"); + } + + // Actually write the DNS response message bytes to the UDP + // socket. + let _ = Self::send_to( + &state.sock, + bytes, + &addr, + state.write_timeout, + ) + .await; + + metrics.dec_num_pending_writes(); + metrics.inc_num_sent_responses(); + } + } + } + } + }); } } } @@ -558,7 +619,7 @@ where fn process_server_command( &self, res: Result<(), watch::error::RecvError>, - command_rx: &mut CommandReceiver, + command_rx: &mut CommandReceiver, //, ) -> Result<(), String> { // If the parent server no longer exists but was not cleanly shutdown // then the command channel will be closed and attempting to check for @@ -644,9 +705,8 @@ where /// into a [`RequestState`] ready for passing through the /// [`CommonMessageFlow`] call chain and ultimately back to ourselves at /// [`process_call_reusult`]. - fn mk_state_for_request( - &self, - ) -> RequestState { + fn mk_state_for_request(&self) -> RequestState { + //}, Buf::Output, Svc::Target> { RequestState::new( self.sock.clone(), self.command_tx.clone(), @@ -655,85 +715,85 @@ where } } -//--- CommonMessageFlow - -impl CommonMessageFlow - for DgramServer -where - Sock: AsyncDgramSock + Send + Sync + 'static, - Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, - Svc::Target: Send + Composer + Default, -{ - type Meta = RequestState; - - /// Add information to the request that relates to the type of server we - /// are and our state where relevant. - fn add_context_to_request( - &self, - request: Message, - received_at: Instant, - addr: SocketAddr, - ) -> Request { - let ctx = - UdpTransportContext::new(self.config.load().max_response_size); - let ctx = TransportSpecificContext::Udp(ctx); - Request::new(addr, received_at, request, ctx) - } - - /// Process the result from the middleware -> service -> middleware call - /// tree. - fn process_call_result( - request: &Request, - call_result: CallResult, - state: RequestState, - metrics: Arc, - ) { - metrics.inc_num_pending_writes(); - let client_addr = request.client_addr(); - - tokio::spawn(async move { - let (response, feedback) = call_result.into_inner(); - - if let Some(feedback) = feedback { - match feedback { - ServiceFeedback::Reconfigure { - idle_timeout: _, // N/A - only applies to connection-oriented transports - } => { - // Nothing to do. - } - } - } - - // Process the DNS response message, if any. - if let Some(response) = response { - // Convert the DNS response message into bytes. - let target = response.finish(); - let bytes = target.as_dgram_slice(); - - // Logging - if enabled!(Level::TRACE) { - let pcap_text = to_pcap_text(bytes, bytes.len()); - trace!(%client_addr, pcap_text, "Sending response"); - } - - // Actually write the DNS response message bytes to the UDP - // socket. - let _ = Self::send_to( - &state.sock, - bytes, - &client_addr, - state.write_timeout, - ) - .await; - - metrics.dec_num_pending_writes(); - metrics.inc_num_sent_responses(); - } - }); - } -} +// //--- CommonMessageFlow + +// impl CommonMessageFlow +// for DgramServer +// where +// Sock: AsyncDgramSock + Send + Sync + 'static, +// Buf: BufSource + Send + Sync + 'static, +// Buf::Output: Octets + Send + Sync + 'static, +// Svc: Service + Send + Sync + 'static, +// Svc::Target: Send + Composer + Default, +// { +// type Meta = RequestState; + +// /// Add information to the request that relates to the type of server we +// /// are and our state where relevant. +// fn add_context_to_request( +// &self, +// request: Message, +// received_at: Instant, +// addr: SocketAddr, +// ) -> Request { +// let ctx = +// UdpTransportContext::new(self.config.load().max_response_size); +// let ctx = TransportSpecificContext::Udp(ctx); +// Request::new(addr, received_at, request, ctx) +// } + +// /// Process the result from the middleware -> service -> middleware call +// /// tree. +// fn process_call_result( +// request: &Request, +// call_result: CallResult, +// state: RequestState, +// metrics: Arc, +// ) { +// metrics.inc_num_pending_writes(); +// let client_addr = request.client_addr(); + +// tokio::spawn(async move { +// let (response, feedback) = call_result.into_inner(); + +// if let Some(feedback) = feedback { +// match feedback { +// ServiceFeedback::Reconfigure { +// idle_timeout: _, // N/A - only applies to connection-oriented transports +// } => { +// // Nothing to do. +// } +// } +// } + +// // Process the DNS response message, if any. +// if let Some(response) = response { +// // Convert the DNS response message into bytes. +// let target = response.finish(); +// let bytes = target.as_dgram_slice(); + +// // Logging +// if enabled!(Level::TRACE) { +// let pcap_text = to_pcap_text(bytes, bytes.len()); +// trace!(%client_addr, pcap_text, "Sending response"); +// } + +// // Actually write the DNS response message bytes to the UDP +// // socket. +// let _ = Self::send_to( +// &state.sock, +// bytes, +// &client_addr, +// state.write_timeout, +// ) +// .await; + +// metrics.dec_num_pending_writes(); +// metrics.inc_num_sent_responses(); +// } +// }); +// } +// } //--- Drop @@ -742,7 +802,7 @@ where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + 'static, + Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, { fn drop(&mut self) { @@ -757,7 +817,8 @@ where /// Data needed by [`DgramServer::process_call_result`] which needs to be /// passed through the [`CommonMessageFlow`] call chain. -pub struct RequestState { +pub struct RequestState { + //, RequestOctets, Target> { /// The network socket over which this request was received and over which /// the response should be sent. sock: Arc, @@ -765,18 +826,19 @@ pub struct RequestState { /// A sender for sending [`ServerCommand`]s. /// /// Used to signal the server to stop, reconfigure, etc. - command_tx: CommandSender, + command_tx: CommandSender, //, /// The maximum amount of time to wait for a response datagram to be /// accepted by the operating system for writing back to the client. write_timeout: Duration, } -impl RequestState { +impl RequestState { + //, RequestOctets, Target> { /// Creates a new instance of [`RequestState`]. fn new( sock: Arc, - command_tx: CommandSender, + command_tx: CommandSender, //, write_timeout: Duration, ) -> Self { Self { @@ -789,8 +851,8 @@ impl RequestState { //--- Clone -impl Clone - for RequestState +impl Clone + for RequestState { fn clone(&self) -> Self { Self { diff --git a/src/net/server/message.rs b/src/net/server/message.rs index bf1d71dbc..2c2cec0ff 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -13,7 +13,7 @@ use tracing::{enabled, error, info_span, warn}; use crate::base::Message; use crate::net::server::buf::BufSource; use crate::net::server::metrics::ServerMetrics; -use crate::net::server::middleware::chain::MiddlewareChain; +// use crate::net::server::middleware::chain::MiddlewareChain; use super::service::{CallResult, Service, ServiceError, Transaction}; use super::util::start_reply; @@ -214,320 +214,313 @@ impl> Clone for Request { } } -//----------- CommonMessageFlow ---------------------------------------------- - -/// Perform processing common to all messages being handled by a DNS server. -/// -/// All messages received by a DNS server need to pass through the following -/// processing stages: -/// -/// - Pre-processing. -/// - Service processing. -/// - Post-processing. -/// -/// The strategy is common but some server specific aspects are delegated to -/// the server that implements this trait: -/// -/// - Adding context to a request. -/// - Finalizing the handling of a response. -/// -/// Servers implement this trait to benefit from the common processing -/// required while still handling aspects specific to the server themselves. -/// -/// Processing starts at [`process_request`]. -/// -///
-/// -/// This trait exists as a convenient mechanism for sharing common code -/// between server implementations. The default function implementations -/// provided by this trait are not intended to be overridden by consumers of -/// this library. -/// -///
-/// -/// [`process_request`]: Self::process_request() -pub trait CommonMessageFlow -where - Buf: BufSource, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync, -{ - /// Server-specific data that it chooses to pass along with the request in - /// order that it may receive it when `process_call_result()` is - /// invoked on the implementing server. - type Meta: Clone + Send + Sync + 'static; - - /// Process a DNS request message. - /// - /// This function consumes the given message buffer and processes the - /// contained message, if any, to completion, possibly resulting in a - /// response being passed to [`Self::process_call_result`]. - /// - /// The request message is a given as a seqeuence of bytes in `buf` - /// originating from client address `addr`. - /// - /// The [`MiddlewareChain`] and [`Service`] to be used to process the - /// message are supplied in the `middleware_chain` and `svc` arguments - /// respectively. - /// - /// Any server specific state to be used and/or updated as part of the - /// processing should be supplied via the `state` argument whose type is - /// defined by the implementing type. - /// - /// On error the result will be a [`ServiceError`]. - #[allow(clippy::too_many_arguments)] - fn process_request( - &self, - buf: Buf::Output, - received_at: Instant, - addr: SocketAddr, - middleware_chain: MiddlewareChain, - svc: &Svc, - metrics: Arc, - meta: Self::Meta, - ) -> Result<(), ServiceError> - where - Svc: 'static, - Svc::Target: Send + Composer + Default, - Svc::Future: Send, - Buf::Output: 'static, - { - boomerang( - self, - buf, - received_at, - addr, - middleware_chain, - metrics, - svc, - meta, - ) - } - - /// Add context to a request. - /// - /// The server supplies this function to annotate the received message - /// with additional information about its origins. - fn add_context_to_request( - &self, - request: Message, - received_at: Instant, - addr: SocketAddr, - ) -> Request; - - /// Finalize a response. - /// - /// The server supplies this function to handle the response as - /// appropriate for the server, e.g. to write the response back to the - /// originating client. - /// - /// The response is the form of a [`CallResult`]. - fn process_call_result( - request: &Request, - call_result: CallResult, - state: Self::Meta, - metrics: Arc, - ); -} - -/// Propogate a message through the [`MiddlewareChain`] to the [`Service`] and -/// flow the response in reverse back down the same path, a bit like throwing -/// a boomerang. -#[allow(clippy::too_many_arguments)] -fn boomerang( - server: &Server, - buf: ::Output, - received_at: Instant, - addr: SocketAddr, - middleware_chain: MiddlewareChain< - ::Output, - ::Output>>::Target, - >, - metrics: Arc, - svc: &Svc, - meta: Server::Meta, -) -> Result<(), ServiceError> -where - Buf: BufSource, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, - Svc::Future: Send, - Svc::Target: Send + Composer + Default, - Server: CommonMessageFlow + ?Sized, -{ - let message = Message::from_octets(buf).map_err(|err| { - warn!("Failed while parsing request message: {err}"); - ServiceError::InternalError - })?; - - let request = server.add_context_to_request(message, received_at, addr); - - let preprocessing_result = do_middleware_preprocessing::( - &request, - &middleware_chain, - &metrics, - )?; - - let (txn, aborted_preprocessor_idx) = - do_service_call::(preprocessing_result, &request, svc); - - do_middleware_postprocessing::( - request, - meta, - middleware_chain, - txn, - aborted_preprocessor_idx, - metrics, - ); - - Ok(()) -} - -/// Pass a pre-processed request to the [`Service`] to handle. -/// -/// If [`Service::call`] returns an error this function will produce a DNS -/// ServFail error response. If the returned error is -/// [`ServiceError::InternalError`] it will also be logged. -#[allow(clippy::type_complexity)] -fn do_service_call( - preprocessing_result: ControlFlow<( - Transaction, - usize, - )>, - request: &Request<::Output>, - svc: &Svc, -) -> (Transaction, Option) -where - Buf: BufSource, - Buf::Output: Octets, - Svc: Service, - Svc::Target: Composer + Default, -{ - match preprocessing_result { - ControlFlow::Continue(()) => { - let res = if enabled!(Level::INFO) { - let span = info_span!("svc-call", - msg_id = request.message().header().id(), - client = %request.client_addr(), - ); - let _guard = span.enter(); - svc.call(request.clone()) - } else { - svc.call(request.clone()) - }; - - // Handle any error returned by the service. - let txn = res.unwrap_or_else(|err| { - if matches!(err, ServiceError::InternalError) { - error!("Service error while processing request: {err}"); - } - - let mut response = start_reply(request); - response.header_mut().set_rcode(err.rcode()); - let call_result = CallResult::new(response.additional()); - Transaction::immediate(Ok(call_result)) - }); - - // Pass the transaction out for post-processing. - (txn, None) - } - - ControlFlow::Break((txn, aborted_preprocessor_idx)) => { - (txn, Some(aborted_preprocessor_idx)) - } - } -} - -/// Pre-process a request. -/// -/// Pre-processing involves parsing a [`Message`] from the byte buffer and -/// pre-processing it via any supplied [`MiddlewareChain`]. -/// -/// On success the result is an immutable request message and a -/// [`ControlFlow`] decision about whether to continue with further processing -/// or to break early with a possible response. If processing failed the -/// result will be a [`ServiceError`]. -/// -/// On break the result will be one ([`Transaction::single`]) or more -/// ([`Transaction::stream`]) to post-process. -#[allow(clippy::type_complexity)] -fn do_middleware_preprocessing( - request: &Request, - middleware_chain: &MiddlewareChain, - metrics: &Arc, -) -> Result< - ControlFlow<(Transaction, usize)>, - ServiceError, -> -where - Buf: BufSource, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync, - Svc::Future: Send, - Svc::Target: Send + Composer + Default + 'static, -{ - let span = info_span!("pre-process", - msg_id = request.message().header().id(), - client = %request.client_addr(), - ); - let _guard = span.enter(); - - metrics.inc_num_inflight_requests(); - - let pp_res = middleware_chain.preprocess(request); - - Ok(pp_res) -} - -/// Post-process a response in the context of its originating request. -/// -/// Each response is post-processed in its own Tokio task. Note that there is -/// no guarantee about the order in which responses will be post-processed. If -/// the order of a seqence of responses is important it should be provided as -/// a [`Transaction::stream`] rather than [`Transaction::single`]. -/// -/// Responses are first post-processed by the [`MiddlewareChain`] provided, if -/// any, then passed to [`Self::process_call_result`] for final processing. -#[allow(clippy::type_complexity)] -fn do_middleware_postprocessing( - request: Request, - meta: Server::Meta, - middleware_chain: MiddlewareChain, - mut response_txn: Transaction, - last_processor_id: Option, - metrics: Arc, -) where - Buf: BufSource, - Buf::Output: Octets + Send + Sync + 'static, - Svc: Service + Send + Sync + 'static, - Svc::Future: Send, - Svc::Target: Send + Composer + Default, - Server: CommonMessageFlow + ?Sized, -{ - tokio::spawn(async move { - let span = info_span!("post-process", - msg_id = request.message().header().id(), - client = %request.client_addr(), - ); - let _guard = span.enter(); - - while let Some(Ok(mut call_result)) = response_txn.next().await { - if let Some(response) = call_result.get_response_mut() { - middleware_chain.postprocess( - &request, - response, - last_processor_id, - ); - } - - Server::process_call_result( - &request, - call_result, - meta.clone(), - metrics.clone(), - ); - } - - metrics.dec_num_inflight_requests(); - }); -} +// //----------- CommonMessageFlow ---------------------------------------------- + +// /// Perform processing common to all messages being handled by a DNS server. +// /// +// /// All messages received by a DNS server need to pass through the following +// /// processing stages: +// /// +// /// - Pre-processing. +// /// - Service processing. +// /// - Post-processing. +// /// +// /// The strategy is common but some server specific aspects are delegated to +// /// the server that implements this trait: +// /// +// /// - Adding context to a request. +// /// - Finalizing the handling of a response. +// /// +// /// Servers implement this trait to benefit from the common processing +// /// required while still handling aspects specific to the server themselves. +// /// +// /// Processing starts at [`process_request`]. +// /// +// ///
+// /// +// /// This trait exists as a convenient mechanism for sharing common code +// /// between server implementations. The default function implementations +// /// provided by this trait are not intended to be overridden by consumers of +// /// this library. +// /// +// ///
+// /// +// /// [`process_request`]: Self::process_request() +// pub trait CommonMessageFlow +// where +// Buf: BufSource, +// Buf::Output: Octets + Send + Sync, +// Svc: Service + Send + Sync, +// { +// /// Server-specific data that it chooses to pass along with the request in +// /// order that it may receive it when `process_call_result()` is +// /// invoked on the implementing server. +// type Meta: Clone + Send + Sync + 'static; + +// /// Process a DNS request message. +// /// +// /// This function consumes the given message buffer and processes the +// /// contained message, if any, to completion, possibly resulting in a +// /// response being passed to [`Self::process_call_result`]. +// /// +// /// The request message is a given as a seqeuence of bytes in `buf` +// /// originating from client address `addr`. +// /// +// /// The [`MiddlewareChain`] and [`Service`] to be used to process the +// /// message are supplied in the `middleware_chain` and `svc` arguments +// /// respectively. +// /// +// /// Any server specific state to be used and/or updated as part of the +// /// processing should be supplied via the `state` argument whose type is +// /// defined by the implementing type. +// /// +// /// On error the result will be a [`ServiceError`]. +// #[allow(clippy::too_many_arguments)] +// fn process_request( +// &self, +// buf: Buf::Output, +// received_at: Instant, +// addr: SocketAddr, +// middleware_chain: MiddlewareChain, +// svc: &Svc, +// metrics: Arc, +// meta: Self::Meta, +// ) -> Result<(), ServiceError> +// where +// Svc: 'static, +// Svc::Target: Send + Composer + Default, +// Svc::Stream: Send, +// Buf::Output: 'static, +// { +// boomerang( +// self, +// buf, +// received_at, +// addr, +// middleware_chain, +// metrics, +// svc, +// meta, +// ) +// } + +// /// Add context to a request. +// /// +// /// The server supplies this function to annotate the received message +// /// with additional information about its origins. +// fn add_context_to_request( +// &self, +// request: Message, +// received_at: Instant, +// addr: SocketAddr, +// ) -> Request; + +// /// Finalize a response. +// /// +// /// The server supplies this function to handle the response as +// /// appropriate for the server, e.g. to write the response back to the +// /// originating client. +// /// +// /// The response is the form of a [`CallResult`]. +// fn process_call_result( +// request: &Request, +// call_result: CallResult, +// state: Self::Meta, +// metrics: Arc, +// ); +// } + +// /// Propogate a message through the [`MiddlewareChain`] to the [`Service`] and +// /// flow the response in reverse back down the same path, a bit like throwing +// /// a boomerang. +// #[allow(clippy::too_many_arguments)] +// fn boomerang( +// server: &Server, +// buf: ::Output, +// received_at: Instant, +// addr: SocketAddr, +// middleware_chain: MiddlewareChain< +// ::Output, +// ::Output>>::Target, +// >, +// metrics: Arc, +// svc: &Svc, +// meta: Server::Meta, +// ) -> Result<(), ServiceError> +// where +// Buf: BufSource, +// Buf::Output: Octets + Send + Sync + 'static, +// Svc: Service + Send + Sync + 'static, +// Svc::Stream: Send, +// Svc::Target: Send + Composer + Default, +// Server: CommonMessageFlow + ?Sized, +// { +// let message = Message::from_octets(buf).map_err(|err| { +// warn!("Failed while parsing request message: {err}"); +// ServiceError::InternalError +// })?; + +// let request = server.add_context_to_request(message, received_at, addr); + +// let preprocessing_result = do_middleware_preprocessing::( +// &request, +// &middleware_chain, +// &metrics, +// )?; + +// let (txn, aborted_preprocessor_idx) = +// do_service_call::(preprocessing_result, &request, svc); + +// do_middleware_postprocessing::( +// request, +// meta, +// middleware_chain, +// txn, +// aborted_preprocessor_idx, +// metrics, +// ); + +// Ok(()) +// } + +// /// Pass a pre-processed request to the [`Service`] to handle. +// /// +// /// If [`Service::call`] returns an error this function will produce a DNS +// /// ServFail error response. If the returned error is +// /// [`ServiceError::InternalError`] it will also be logged. +// #[allow(clippy::type_complexity)] +// fn do_service_call( +// preprocessing_result: ControlFlow<(Svc::Stream, usize)>, +// request: &Request<::Output>, +// svc: &Svc, +// ) -> (Svc::Stream, Option) +// where +// Buf: BufSource, +// Buf::Output: Octets, +// Svc: Service, +// Svc::Target: Composer + Default, +// { +// match preprocessing_result { +// ControlFlow::Continue(()) => { +// let res = if enabled!(Level::INFO) { +// let span = info_span!("svc-call", +// msg_id = request.message().header().id(), +// client = %request.client_addr(), +// ); +// let _guard = span.enter(); +// svc.call(request.clone()) +// } else { +// svc.call(request.clone()) +// }; + +// // Handle any error returned by the service. +// // let txn = res.unwrap_or_else(|err| { +// // if matches!(err, ServiceError::InternalError) { +// // error!("Service error while processing request: {err}"); +// // } + +// // let mut response = start_reply(request); +// // response.header_mut().set_rcode(err.rcode()); +// // let call_result = CallResult::new(response.additional()); +// // Transaction::immediate(Ok(call_result)) +// // }); + +// // Pass the transaction out for post-processing. +// (res, None) +// } + +// ControlFlow::Break((txn, aborted_preprocessor_idx)) => { +// (txn, Some(aborted_preprocessor_idx)) +// } +// } +// } + +// /// Pre-process a request. +// /// +// /// Pre-processing involves parsing a [`Message`] from the byte buffer and +// /// pre-processing it via any supplied [`MiddlewareChain`]. +// /// +// /// On success the result is an immutable request message and a +// /// [`ControlFlow`] decision about whether to continue with further processing +// /// or to break early with a possible response. If processing failed the +// /// result will be a [`ServiceError`]. +// /// +// /// On break the result will be one ([`Transaction::single`]) or more +// /// ([`Transaction::stream`]) to post-process. +// #[allow(clippy::type_complexity)] +// fn do_middleware_preprocessing( +// request: &Request, +// middleware_chain: &MiddlewareChain, +// metrics: &Arc, +// ) -> Result, ServiceError> +// where +// Buf: BufSource, +// Buf::Output: Octets + Send + Sync + 'static, +// Svc: Service + Send + Sync, +// Svc::Target: Send + Composer + Default + 'static, +// { +// let span = info_span!("pre-process", +// msg_id = request.message().header().id(), +// client = %request.client_addr(), +// ); +// let _guard = span.enter(); + +// metrics.inc_num_inflight_requests(); + +// let pp_res = middleware_chain.preprocess(request); + +// Ok(pp_res) +// } + +// /// Post-process a response in the context of its originating request. +// /// +// /// Each response is post-processed in its own Tokio task. Note that there is +// /// no guarantee about the order in which responses will be post-processed. If +// /// the order of a seqence of responses is important it should be provided as +// /// a [`Transaction::stream`] rather than [`Transaction::single`]. +// /// +// /// Responses are first post-processed by the [`MiddlewareChain`] provided, if +// /// any, then passed to [`Self::process_call_result`] for final processing. +// #[allow(clippy::type_complexity)] +// fn do_middleware_postprocessing( +// request: Request, +// meta: Server::Meta, +// middleware_chain: MiddlewareChain, +// mut response_txn: Svc::Stream, +// last_processor_id: Option, +// metrics: Arc, +// ) where +// Buf: BufSource, +// Buf::Output: Octets + Send + Sync + 'static, +// Svc: Service + Send + Sync + 'static, +// Svc::Stream: Send, +// Svc::Target: Send + Composer + Default, +// Server: CommonMessageFlow + ?Sized, +// { +// tokio::spawn(async move { +// let span = info_span!("post-process", +// msg_id = request.message().header().id(), +// client = %request.client_addr(), +// ); +// let _guard = span.enter(); + +// while let Some(Ok(mut call_result)) = response_txn.next().await { +// if let Some(response) = call_result.get_response_mut() { +// middleware_chain.postprocess( +// &request, +// response, +// last_processor_id, +// ); +// } + +// Server::process_call_result( +// &request, +// call_result, +// meta.clone(), +// metrics.clone(), +// ); +// } + +// metrics.dec_num_inflight_requests(); +// }); +// } diff --git a/src/net/server/middleware/chain.rs b/src/net/server/middleware/chain.rs index 57d4c9956..4546f8d1d 100644 --- a/src/net/server/middleware/chain.rs +++ b/src/net/server/middleware/chain.rs @@ -1,10 +1,13 @@ //! Chaining [`MiddlewareProcessor`]s together. +use core::future::ready; use core::ops::{ControlFlow, RangeTo}; use std::fmt::Debug; use std::sync::Arc; use std::vec::Vec; +use futures::stream::once; + use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; use crate::base::StreamTarget; @@ -97,15 +100,15 @@ where /// put pre-processors which protect the server against doing too much /// work as early in the chain as possible. #[allow(clippy::type_complexity)] - pub fn preprocess( + pub fn preprocess( &self, request: &Request, - ) -> ControlFlow<(Transaction, usize)> - where - Future: std::future::Future< - Output = Result, ServiceError>, + ) -> ControlFlow<( + impl futures::stream::Stream< + Item = Result, ServiceError>, > + Send, - { + usize, + )> { for (i, p) in self.processors.iter().enumerate() { match p.preprocess(request) { ControlFlow::Continue(()) => { @@ -115,11 +118,8 @@ where ControlFlow::Break(response) => { // Stop pre-processing, return the produced response // (after first applying post-processors to it). - let item = Ok(CallResult::new(response)); - return ControlFlow::Break(( - Transaction::immediate(item), - i, - )); + let item = ready(Ok(CallResult::new(response))); + return ControlFlow::Break((once(item), i)); } } } diff --git a/src/net/server/middleware/processors/mandatory_svc.rs b/src/net/server/middleware/processors/mandatory_svc.rs new file mode 100644 index 000000000..fd25bc616 --- /dev/null +++ b/src/net/server/middleware/processors/mandatory_svc.rs @@ -0,0 +1,527 @@ +//! Core DNS RFC standards based message processing for MUST requirements. +use core::ops::ControlFlow; + +use futures::StreamExt; +use octseq::Octets; +use tracing::{debug, error, trace, warn}; + +use crate::base::iana::{Opcode, Rcode}; +use crate::base::message_builder::{AdditionalBuilder, PushError}; +use crate::base::wire::{Composer, ParseError}; +use crate::base::StreamTarget; +use crate::net::server::message::{Request, TransportSpecificContext}; +use crate::net::server::service::{ + CallResult, Service, ServiceError, Transaction, +}; +use crate::net::server::util::{mk_builder_for_target, start_reply}; +use core::marker::PhantomData; +use std::fmt::Display; + +/// The minimum legal UDP response size in bytes. +/// +/// As defined by [RFC 1035 section 4.2.1]. +/// +/// [RFC 1035 section 4.2.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 +pub const MINIMUM_RESPONSE_BYTE_LEN: u16 = 512; + +/// A [`MiddlewareProcessor`] for enforcing core RFC MUST requirements on +/// processed messages. +/// +/// Standards covered by ths implementation: +/// +/// | RFC | Status | +/// |--------|---------| +/// | [1035] | TBD | +/// | [2181] | TBD | +/// +/// [`MiddlewareProcessor`]: +/// crate::net::server::middleware::processor::MiddlewareProcessor +/// [1035]: https://datatracker.ietf.org/doc/html/rfc1035 +/// [2181]: https://datatracker.ietf.org/doc/html/rfc2181 +#[derive(Debug)] +pub struct MandatoryMiddlewareSvc +where + RequestOctets: AsRef<[u8]>, + S: Service, +{ + /// In strict mode the processor does more checks on requests and + /// responses. + strict: bool, + + inner: S, + + _phantom: PhantomData, +} + +impl MandatoryMiddlewareSvc +where + RequestOctets: Octets, + S: Service, +{ + /// Creates a new processor instance. + /// + /// The processor will operate in strict mode. + #[must_use] + pub fn new(inner: S) -> Self { + Self { + strict: true, + inner, + _phantom: PhantomData, + } + } + + /// Creates a new processor instance. + /// + /// The processor will operate in relaxed mode. + #[must_use] + pub fn relaxed(inner: S) -> Self { + Self { + strict: false, + inner, + _phantom: PhantomData, + } + } + + /// Create a DNS error response to the given request with the given RCODE. + fn error_response( + &self, + request: &Request, + rcode: Rcode, + ) -> AdditionalBuilder> + where + S::Target: Composer + Default, + { + let mut response = start_reply(request); + response.header_mut().set_rcode(rcode); + let mut additional = response.additional(); + self.postprocess(request, &mut additional); + additional + } +} + +impl MandatoryMiddlewareSvc +where + RequestOctets: Octets, + S: Service, +{ + /// Truncate the given response message if it is too large. + /// + /// Honours either a transport supplied hint, if present in the given + /// [`UdpSpecificTransportContext`], as to how large the response is + /// allowed to be, or if missing will instead honour the clients indicated + /// UDP response payload size (if an EDNS OPT is present in the request). + /// + /// Truncation discards the authority and additional sections, except for + /// any OPT record present which will be preserved, then truncates to the + /// specified byte length. + fn truncate( + request: &Request, + response: &mut AdditionalBuilder>, + ) -> Result<(), TruncateError> + where + S::Target: Composer + Default, + { + if let TransportSpecificContext::Udp(ctx) = request.transport_ctx() { + // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 + // "Messages carried by UDP are restricted to 512 bytes (not + // counting the IP or UDP headers). Longer messages are + // truncated and the TC bit is set in the header." + let max_response_size = ctx + .max_response_size_hint() + .unwrap_or(MINIMUM_RESPONSE_BYTE_LEN); + let max_response_size = max_response_size as usize; + let response_len = response.as_slice().len(); + + if response_len > max_response_size { + // Truncate per RFC 1035 section 6.2 and RFC 2181 sections 5.1 + // and 9: + // + // https://datatracker.ietf.org/doc/html/rfc1035#section-6.2 + // "When a response is so long that truncation is required, + // the truncation should start at the end of the response + // and work forward in the datagram. Thus if there is any + // data for the authority section, the answer section is + // guaranteed to be unique." + // + // https://datatracker.ietf.org/doc/html/rfc2181#section-5.1 + // "A query for a specific (or non-specific) label, class, + // and type, will always return all records in the + // associated RRSet - whether that be one or more RRs. The + // response must be marked as "truncated" if the entire + // RRSet will not fit in the response." + // + // https://datatracker.ietf.org/doc/html/rfc2181#section-9 + // "Where TC is set, the partial RRSet that would not + // completely fit may be left in the response. When a DNS + // client receives a reply with TC set, it should ignore + // that response, and query again, using a mechanism, such + // as a TCP connection, that will permit larger replies." + // + // https://datatracker.ietf.org/doc/html/rfc6891#section-7 + // "The minimal response MUST be the DNS header, question + // section, and an OPT record. This MUST also occur when + // a truncated response (using the DNS header's TC bit) is + // returned." + + // Tell the client that we are truncating the response. + response.header_mut().set_tc(true); + + // Remember the original length. + let old_len = response.as_slice().len(); + + // Copy the header, question and opt record from the + // additional section, but leave the answer and authority + // sections empty. + let source = response.as_message(); + let mut target = mk_builder_for_target(); + + *target.header_mut() = source.header(); + + let mut target = target.question(); + for rr in source.question() { + target.push(rr?)?; + } + + let mut target = target.additional(); + if let Some(opt) = source.opt() { + if let Err(err) = target.push(opt.as_record()) { + warn!("Error while truncating response: unable to push OPT record: {err}"); + // As the client had an OPT record and RFC 6891 says + // when truncating that there MUST be an OPT record, + // attempt to push just the empty OPT record (as the + // OPT record header still has value, e.g. the + // requestors payload size field and extended rcode). + if let Err(err) = target.opt(|builder| { + builder.set_version(opt.version()); + builder.set_rcode(opt.rcode(response.header())); + builder + .set_udp_payload_size(opt.udp_payload_size()); + Ok(()) + }) { + error!("Error while truncating response: unable to add minimal OPT record: {err}"); + } + } + } + + let new_len = target.as_slice().len(); + trace!("Truncating response from {old_len} bytes to {new_len} bytes"); + + *response = target; + } + } + + Ok(()) + } + + fn preprocess( + &self, + request: &Request, + ) -> ControlFlow>> + where + S::Target: Composer + Default, + { + // https://www.rfc-editor.org/rfc/rfc3425.html + // 3 - Effect on RFC 1035 + // .. + // "Therefore IQUERY is now obsolete, and name servers SHOULD return + // a "Not Implemented" error when an IQUERY request is received." + if self.strict + && request.message().header().opcode() == Opcode::IQUERY + { + debug!( + "RFC 3425 3 violation: request opcode IQUERY is obsolete." + ); + return ControlFlow::Break( + self.error_response(request, Rcode::NOTIMP), + ); + } + + ControlFlow::Continue(()) + } + + fn postprocess( + &self, + request: &Request, + response: &mut AdditionalBuilder>, + ) where + S::Target: Composer + Default, + { + if let Err(err) = Self::truncate(request, response) { + error!("Error while truncating response: {err}"); + *response = self.error_response(request, Rcode::SERVFAIL); + return; + } + + // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 + // 4.1.1: Header section format + // + // ID A 16 bit identifier assigned by the program that + // generates any kind of query. This identifier is copied + // the corresponding reply and can be used by the requester + // to match up replies to outstanding queries. + response + .header_mut() + .set_id(request.message().header().id()); + + // QR A one bit field that specifies whether this message is a + // query (0), or a response (1). + response.header_mut().set_qr(true); + + // RD Recursion Desired - this bit may be set in a query and + // is copied into the response. If RD is set, it directs + // the name server to pursue the query recursively. + // Recursive query support is optional. + response + .header_mut() + .set_rd(request.message().header().rd()); + + // https://www.rfc-editor.org/rfc/rfc1035.html + // https://www.rfc-editor.org/rfc/rfc3425.html + // + // All responses shown in RFC 1035 (except those for inverse queries, + // opcode 1, which was obsoleted by RFC 4325) contain the question + // from the request. So we would expect the number of questions in the + // response to match the number of questions in the request. + if self.strict + && !request.message().header_counts().qdcount() + == response.counts().qdcount() + { + warn!("RFC 1035 violation: response question count != request question count"); + } + } +} + +//--- Service + +impl Service + for MandatoryMiddlewareSvc +where + RequestOctets: Octets, + S: Service, + S::Target: Composer + Default, +{ + type Target = S::Target; + + type Stream = S::Stream; + + fn call( + &self, + request: Request, + ) -> Self::Stream { + match self.preprocess(&request) { + ControlFlow::Continue(()) => { + self.inner.call(request).map(|res| { + res.and_then(|cr| { + cr.get_response_mut().and_then(|response| { + self.postprocess(&request, response); + Some(response) + }); + Ok(cr) + }) + }) + } + ControlFlow::Break(mut response) => { + self.postprocess(&request, &mut response); + Ok(Transaction::immediate(Ok(CallResult::new(response)))) + } + } + } +} + +// impl MiddlewareProcessor +// for MandatoryMiddlewareSvc +// where +// RequestOctets: Octets, +// S: Service, +// S::Target: Composer + Default, +// { +// fn preprocess( +// &self, +// request: &Request, +// ) -> ControlFlow>> { +// self.p +// } + +// fn postprocess( +// &self, +// request: &Request, +// response: &mut AdditionalBuilder>, +// ) { +// todo!() +// } +// } + +//------------ TruncateError ------------------------------------------------- + +/// An error occured during oversize response truncation. +enum TruncateError { + /// There was a problem parsing the request, specifically the question + /// section. + InvalidQuestion(ParseError), + + /// There was a problem pushing to the response. + PushFailure(PushError), +} + +impl Display for TruncateError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + TruncateError::InvalidQuestion(err) => { + write!(f, "Unable to parse question: {err}") + } + TruncateError::PushFailure(err) => { + write!(f, "Unable to push into response: {err}") + } + } + } +} + +impl From for TruncateError { + fn from(err: ParseError) -> Self { + Self::InvalidQuestion(err) + } +} + +impl From for TruncateError { + fn from(err: PushError) -> Self { + Self::PushFailure(err) + } +} + +#[cfg(test)] +mod tests { + use core::ops::ControlFlow; + + use std::vec::Vec; + + use bytes::Bytes; + use tokio::time::Instant; + + use crate::base::{Dname, MessageBuilder, Rtype}; + use crate::net::server::message::{ + Request, TransportSpecificContext, UdpTransportContext, + }; + + use super::MandatoryMiddlewareSvc; + use crate::base::iana::{OptionCode, Rcode}; + use crate::net::server::middleware::processor::MiddlewareProcessor; + use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; + use crate::net::server::service::{ + CallResult, ServiceError, Transaction, + }; + use crate::net::server::util::{mk_builder_for_target, service_fn}; + use core::pin::Pin; + use octseq::OctetsBuilder; + use std::boxed::Box; + use std::future::Future; + + //------------ Constants ------------------------------------------------- + + const MIN_ALLOWED: u16 = MINIMUM_RESPONSE_BYTE_LEN; + const TOO_SMALL: u16 = 511; + const JUST_RIGHT: u16 = MIN_ALLOWED; + const HUGE: u16 = u16::MAX; + + //------------ Tests ----------------------------------------------------- + + #[test] + fn clamp_max_response_size_correctly() { + assert!(process(None) <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); + assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); + assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); + assert!(process(Some(HUGE)) <= Some(HUGE as usize)); + assert!(process(Some(HUGE)) <= Some(HUGE as usize)); + assert!(process(Some(HUGE)) <= Some(HUGE as usize)); + } + + //------------ Helper functions ------------------------------------------ + + // Returns Some(n) if truncation occurred where n is the size after + // truncation. + fn process(max_response_size_hint: Option) -> Option { + // Build a dummy DNS query. + let query = MessageBuilder::new_vec(); + let mut query = query.question(); + query.push((Dname::::root(), Rtype::A)).unwrap(); + let extra_bytes = vec![0; (MIN_ALLOWED as usize) * 2]; + let mut additional = query.additional(); + additional + .opt(|builder| { + builder.push_raw_option( + OptionCode::PADDING, + extra_bytes.len() as u16, + |target| { + target.append_slice(&extra_bytes).unwrap(); + Ok(()) + }, + ) + }) + .unwrap(); + let old_size = additional.as_slice().len(); + let message = additional.into_message(); + + // TODO: Artificially expand the message to be as big as possible + // so that it will get truncated. + + // Package the query into a context aware request to make it look + // as if it came from a UDP server. + let ctx = UdpTransportContext::new(max_response_size_hint); + let request = Request::new( + "127.0.0.1:12345".parse().unwrap(), + Instant::now(), + message, + TransportSpecificContext::Udp(ctx), + ); + + fn my_service( + req: Request>, + _meta: (), + ) -> Result< + Transaction< + Vec, + Pin< + Box< + dyn Future< + Output = Result< + CallResult>, + ServiceError, + >, + >, + >, + >, + >, + ServiceError, + > { + // For each request create a single response: + Ok(Transaction::single(Box::pin(async move { + let builder = mk_builder_for_target(); + let answer = + builder.start_answer(req.message(), Rcode::NXDOMAIN)?; + Ok(CallResult::new(answer.additional())) + }))) + } + + // And pass the query through the middleware processor + let processor = + MandatoryMiddlewareSvc::new(service_fn(my_service, ())); + let processor: &dyn MiddlewareProcessor, Vec> = + &processor; + let mut response = MessageBuilder::new_stream_vec().additional(); + if let ControlFlow::Continue(()) = processor.preprocess(&request) { + processor.postprocess(&request, &mut response); + } + + // Get the response length + let new_size = response.as_slice().len(); + + if new_size < old_size { + Some(new_size) + } else { + None + } + } +} diff --git a/src/net/server/middleware/processors/mod.rs b/src/net/server/middleware/processors/mod.rs index 18635c239..b0add717c 100644 --- a/src/net/server/middleware/processors/mod.rs +++ b/src/net/server/middleware/processors/mod.rs @@ -5,3 +5,4 @@ pub mod cookies; pub mod edns; pub mod mandatory; +pub mod mandatory_svc; diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 13c8c34a7..6381d895f 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -209,18 +209,18 @@ #![cfg(feature = "unstable-server-transport")] #![cfg_attr(docsrs, doc(cfg(feature = "unstable-server-transport")))] -mod connection; -pub use connection::Config as ConnectionConfig; +// mod connection; +// pub use connection::Config as ConnectionConfig; pub mod buf; pub mod dgram; pub mod error; pub mod message; pub mod metrics; -pub mod middleware; +// pub mod middleware; pub mod service; pub mod sock; -pub mod stream; +// pub mod stream; pub mod util; #[cfg(test)] diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 57625025f..81a5f7f6f 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -203,18 +203,18 @@ pub trait Service = Vec> { /// The type of buffer in which response messages are stored. type Target; - /// The type of future returned by [`Service::call`] via - /// [`Transaction::single`]. - type Future: std::future::Future< - Output = Result, ServiceError>, - >; + /// The type of future returned by [`Service::call()`] via + /// [`Transaction::single()`]. + // type Item: ; /// Generate a response to a fully pre-processed request. #[allow(clippy::type_complexity)] fn call( &self, request: Request, - ) -> Result, ServiceError>; + ) -> impl futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + Unpin; } /// Helper trait impl to treat an [`Arc`] as a [`Service`]. @@ -222,34 +222,36 @@ impl, T: Service> Service for Arc { type Target = T::Target; - type Future = T::Future; fn call( &self, request: Request, - ) -> Result, ServiceError> { + ) -> impl futures::stream::Stream< + Item = Result, ServiceError>, + > { Arc::deref(self).call(request) } } /// Helper trait impl to treat a function as a [`Service`]. -impl Service for F +impl Service for F where + RequestOctets: AsRef<[u8]>, F: Fn( Request, - ) -> Result, ServiceError>, - RequestOctets: AsRef<[u8]>, - Future: std::future::Future< - Output = Result, ServiceError>, - >, + ) -> Stream, + Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + Unpin { type Target = Target; - type Future = Future; fn call( &self, request: Request, - ) -> Result, ServiceError> { + ) -> impl futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + Unpin { (*self)(request) } } diff --git a/src/net/server/stream.rs b/src/net/server/stream.rs index 09d813e00..d642b84e5 100644 --- a/src/net/server/stream.rs +++ b/src/net/server/stream.rs @@ -462,7 +462,7 @@ where Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, Svc::Target: Send + Sync, - Svc::Future: Send, + Svc::Stream: Send, { if let Err(err) = self.run_until_error().await { error!("Server stopped due to error: {err}"); @@ -552,7 +552,7 @@ where Listener::Future: Send + 'static, Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, - Svc::Future: Send, + Svc::Stream: Send, Svc::Target: Send + Sync + 'static, { let mut command_rx = self.command_rx.clone(); @@ -675,7 +675,7 @@ where Listener::Future: Send + 'static, Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, - Svc::Future: Send, + Svc::Stream: Send, Svc::Target: Send + Sync + 'static, { // Work around the compiler wanting to move self to the async block by diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index 335740bec..2e2021cb7 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -1,4 +1,3 @@ -use core::future::Future; use core::pin::Pin; use core::str::FromStr; use core::sync::atomic::{AtomicBool, Ordering}; @@ -23,11 +22,9 @@ use crate::base::StreamTarget; use super::buf::BufSource; use super::message::Request; -use super::service::{ - CallResult, Service, ServiceError, ServiceFeedback, Transaction, -}; +use super::service::{CallResult, Service, ServiceError, ServiceFeedback}; use super::sock::AsyncAccept; -use super::stream::StreamServer; +// use super::stream::StreamServer; /// Mock I/O which supplies a sequence of mock messages to the server at a /// defined rate. @@ -267,13 +264,13 @@ impl BufSource for MockBufSource { /// it is possible to define your own. struct MySingle; -impl Future for MySingle { - type Output = Result>, ServiceError>; +impl futures::stream::Stream for MySingle { + type Item = Result>, ServiceError>; - fn poll( + fn poll_next( self: Pin<&mut Self>, _cx: &mut Context<'_>, - ) -> Poll { + ) -> Poll> { let builder = MessageBuilder::new_stream_vec(); let response = builder.additional(); @@ -283,7 +280,7 @@ impl Future for MySingle { let call_result = CallResult::new(response).with_feedback(command); - Poll::Ready(Ok(call_result)) + Poll::Ready(Some(Ok(call_result))) } } @@ -299,13 +296,10 @@ impl MyService { impl Service> for MyService { type Target = Vec; - type Future = MySingle; + // type Stream = MySingle; - fn call( - &self, - _msg: Request>, - ) -> Result, ServiceError> { - Ok(Transaction::single(MySingle)) + fn call(&self, _msg: Request>) -> MySingle { + MySingle } } @@ -340,89 +334,89 @@ fn mk_query() -> StreamTarget> { // signal that time has passed when in fact it actually hasn't, allowing a // time dependent test to run much faster without actual periods of // waiting to allow time to elapse. -#[tokio::test(flavor = "current_thread", start_paused = true)] -async fn service_test() { - let (srv_handle, server_status_printer_handle) = { - let fast_client = MockClientConfig { - new_message_every: Duration::from_millis(100), - messages: VecDeque::from([ - mk_query().as_stream_slice().to_vec(), - mk_query().as_stream_slice().to_vec(), - mk_query().as_stream_slice().to_vec(), - mk_query().as_stream_slice().to_vec(), - mk_query().as_stream_slice().to_vec(), - ]), - client_port: 1, - }; - let slow_client = MockClientConfig { - new_message_every: Duration::from_millis(3000), - messages: VecDeque::from([ - mk_query().as_stream_slice().to_vec(), - mk_query().as_stream_slice().to_vec(), - ]), - client_port: 2, - }; - let num_messages = - fast_client.messages.len() + slow_client.messages.len(); - let streams_to_read = VecDeque::from([fast_client, slow_client]); - let new_client_every = Duration::from_millis(2000); - let listener = MockListener::new(streams_to_read, new_client_every); - let ready_flag = listener.get_ready_flag(); - - let buf = MockBufSource; - let my_service = Arc::new(MyService::new()); - let srv = - Arc::new(StreamServer::new(listener, buf, my_service.clone())); - - let metrics = srv.metrics(); - let server_status_printer_handle = tokio::spawn(async move { - loop { - sleep(Duration::from_millis(250)).await; - eprintln!( - "Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", - metrics.num_connections(), - metrics.num_inflight_requests(), - metrics.num_pending_writes(), - metrics.num_received_requests(), - metrics.num_sent_responses(), - ); - } - }); - - let spawned_srv = srv.clone(); - let srv_handle = tokio::spawn(async move { spawned_srv.run().await }); - - eprintln!("Clients sleeping"); - sleep(Duration::from_secs(1)).await; - - eprintln!("Clients connecting"); - ready_flag.store(true, Ordering::Relaxed); - - // Simulate a wait long enough that all simulated clients had time - // to connect, communicate and disconnect. - sleep(Duration::from_secs(20)).await; - - // Verify that all simulated clients connected. - assert_eq!(0, srv.source().streams_remaining()); - - // Verify that no requests or responses are in progress still in - // the server. - assert_eq!(srv.metrics().num_connections(), 0); - assert_eq!(srv.metrics().num_inflight_requests(), 0); - assert_eq!(srv.metrics().num_pending_writes(), 0); - assert_eq!(srv.metrics().num_received_requests(), num_messages); - assert_eq!(srv.metrics().num_sent_responses(), num_messages); - - eprintln!("Shutting down"); - srv.shutdown().unwrap(); - eprintln!("Shutdown command sent"); - - (srv_handle, server_status_printer_handle) - }; - - eprintln!("Waiting for service to shutdown"); - let _ = srv_handle.await; - - // Terminate the task that periodically prints the server status - server_status_printer_handle.abort(); -} +// #[tokio::test(flavor = "current_thread", start_paused = true)] +// async fn service_test() { +// let (srv_handle, server_status_printer_handle) = { +// let fast_client = MockClientConfig { +// new_message_every: Duration::from_millis(100), +// messages: VecDeque::from([ +// mk_query().as_stream_slice().to_vec(), +// mk_query().as_stream_slice().to_vec(), +// mk_query().as_stream_slice().to_vec(), +// mk_query().as_stream_slice().to_vec(), +// mk_query().as_stream_slice().to_vec(), +// ]), +// client_port: 1, +// }; +// let slow_client = MockClientConfig { +// new_message_every: Duration::from_millis(3000), +// messages: VecDeque::from([ +// mk_query().as_stream_slice().to_vec(), +// mk_query().as_stream_slice().to_vec(), +// ]), +// client_port: 2, +// }; +// let num_messages = +// fast_client.messages.len() + slow_client.messages.len(); +// let streams_to_read = VecDeque::from([fast_client, slow_client]); +// let new_client_every = Duration::from_millis(2000); +// let listener = MockListener::new(streams_to_read, new_client_every); +// let ready_flag = listener.get_ready_flag(); + +// let buf = MockBufSource; +// let my_service = Arc::new(MyService::new()); +// let srv = +// Arc::new(StreamServer::new(listener, buf, my_service.clone())); + +// let metrics = srv.metrics(); +// let server_status_printer_handle = tokio::spawn(async move { +// loop { +// sleep(Duration::from_millis(250)).await; +// eprintln!( +// "Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", +// metrics.num_connections(), +// metrics.num_inflight_requests(), +// metrics.num_pending_writes(), +// metrics.num_received_requests(), +// metrics.num_sent_responses(), +// ); +// } +// }); + +// let spawned_srv = srv.clone(); +// let srv_handle = tokio::spawn(async move { spawned_srv.run().await }); + +// eprintln!("Clients sleeping"); +// sleep(Duration::from_secs(1)).await; + +// eprintln!("Clients connecting"); +// ready_flag.store(true, Ordering::Relaxed); + +// // Simulate a wait long enough that all simulated clients had time +// // to connect, communicate and disconnect. +// sleep(Duration::from_secs(20)).await; + +// // Verify that all simulated clients connected. +// assert_eq!(0, srv.source().streams_remaining()); + +// // Verify that no requests or responses are in progress still in +// // the server. +// assert_eq!(srv.metrics().num_connections(), 0); +// assert_eq!(srv.metrics().num_inflight_requests(), 0); +// assert_eq!(srv.metrics().num_pending_writes(), 0); +// assert_eq!(srv.metrics().num_received_requests(), num_messages); +// assert_eq!(srv.metrics().num_sent_responses(), num_messages); + +// eprintln!("Shutting down"); +// srv.shutdown().unwrap(); +// eprintln!("Shutdown command sent"); + +// (srv_handle, server_status_printer_handle) +// }; + +// eprintln!("Waiting for service to shutdown"); +// let _ = srv_handle.await; + +// // Terminate the task that periodically prints the server status +// server_status_printer_handle.abort(); +// } diff --git a/src/net/server/util.rs b/src/net/server/util.rs index 5a38d1acc..d90e8250d 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -98,21 +98,17 @@ where /// [`Vec`]: std::vec::Vec /// [`CallResult`]: crate::net::server::service::CallResult /// [`Result::Ok`]: std::result::Result::Ok -pub fn service_fn( +pub fn service_fn( request_handler: T, metadata: Metadata, -) -> impl Service + Clone +) -> impl Service + Clone where RequestOctets: AsRef<[u8]>, - Future: std::future::Future< - Output = Result, ServiceError>, - >, + Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + Unpin, Metadata: Clone, - T: Fn( - Request, - Metadata, - ) -> Result, ServiceError> - + Clone, + T: Fn(Request, Metadata) -> Stream + Clone, { move |request| request_handler(request, metadata.clone()) } From 4aec8bd41acfa4725429c9d46c2a466850f9b413 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Mon, 15 Apr 2024 16:11:12 +0200 Subject: [PATCH 02/28] WIP: Working Box based Service andmandatory middleware, with stream server and other middleware disabled. --- examples/serve-zone.rs | 85 ++++-- src/net/server/dgram.rs | 27 ++ .../middleware/processors/mandatory_svc.rs | 254 ++++++++++-------- src/net/server/mod.rs | 2 +- src/net/server/service.rs | 34 +-- src/net/server/tests.rs | 2 +- src/net/server/util.rs | 8 +- 7 files changed, 253 insertions(+), 159 deletions(-) diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index 0390028e2..8d0e6f503 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -27,10 +27,12 @@ use domain::net::server::util::{mk_builder_for_target, service_fn}; use domain::zonefile::inplace; use domain::zonetree::{Answer, Rrset}; use domain::zonetree::{Zone, ZoneTree}; -use futures::stream::{once, FuturesOrdered}; +use futures::stream::{once, FuturesOrdered, Once}; +use futures::StreamExt; use octseq::OctetsBuilder; use std::future::{pending, ready, Future}; use std::io::BufReader; +use std::ops::DerefMut; use std::pin::Pin; use std::sync::{Arc, Mutex}; use std::time::Duration; @@ -109,27 +111,60 @@ async fn main() { pending::<()>().await; } +enum SingleOrStream { + Single( + Once< + Pin< + Box< + dyn std::future::Future< + Output = Result< + CallResult>, + ServiceError, + >, + > + Send, + >, + >, + >, + ), + + Stream( + Box< + dyn futures::stream::Stream< + Item = Result>, ServiceError>, + > + Unpin + Send, + >, + ), +} + +impl futures::stream::Stream for SingleOrStream { + type Item = Result>, ServiceError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + match self.deref_mut() { + SingleOrStream::Single(s) => s.poll_next_unpin(cx), + SingleOrStream::Stream(s) => s.poll_next_unpin(cx), + } + } +} + #[allow(clippy::type_complexity)] fn my_service( request: Request>, zones: Arc, -) -> - Box< - dyn futures::stream::Stream< - Item = Result>, ServiceError>, - > + Send - + Unpin, - > - { +) -> SingleOrStream { let qtype = request.message().sole_question().unwrap().qtype(); match qtype { Rtype::AXFR if request.transport_ctx().is_non_udp() => { - Box::new(handle_axfr_request(request, zones)) - } - _ => { - let fut = Box::pin(handle_non_axfr_request(request, zones)); - Box::new(once(fut)) + SingleOrStream::Stream(Box::new(handle_axfr_request( + request, zones, + ))) } + _ => SingleOrStream::Single(once(Box::pin(handle_non_axfr_request( + request, zones, + )))), } } @@ -159,15 +194,19 @@ fn handle_axfr_request( request: Request>, zones: Arc, ) -> FuturesOrdered< - Pin>, ServiceError>> + Send>>, + Pin< + Box< + dyn Future>, ServiceError>> + + Send, + >, + >, > { // let mut stream = TransactionStream::default(); let mut stream = FuturesOrdered::< Pin< Box< - dyn Future< - Output = Result>, ServiceError>, - > + Send, + dyn Future>, ServiceError>> + + Send, >, >, >::new(); @@ -275,9 +314,8 @@ fn add_to_stream( stream: &mut FuturesOrdered< Pin< Box< - dyn Future< - Output = Result>, ServiceError>, - > + Send, + dyn Future>, ServiceError>> + + Send, >, >, >, @@ -294,9 +332,8 @@ fn add_additional_to_stream( stream: &mut FuturesOrdered< Pin< Box< - dyn Future< - Output = Result>, ServiceError>, - > + Send, + dyn Future>, ServiceError>> + + Send, >, >, >, diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index f9188da99..8ab6a7e9f 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -46,6 +46,7 @@ use crate::utils::config::DefMinMax; use super::buf::VecBufSource; use super::message::{TransportSpecificContext, UdpTransportContext}; +use super::service::ServiceError; // use super::middleware::builder::MiddlewareBuilder; use super::ServerCommand; use crate::base::wire::Composer; @@ -302,6 +303,10 @@ where Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, { /// The configuration of the server. config: Arc*/>>, @@ -340,6 +345,10 @@ where Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + Clone, Svc::Target: Send + Composer + Default, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, { /// Constructs a new [`DgramServer`] with default configuration. /// @@ -395,6 +404,10 @@ where Buf::Output: Octets + Send + Sync + 'static + Debug, Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Debug + Default, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, { /// Get a reference to the network source being used to receive messages. #[must_use] @@ -418,6 +431,10 @@ where Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, { /// Start the server. /// @@ -506,6 +523,10 @@ where Buf::Output: Octets + Send + Sync + 'static, Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, { /// Receive incoming messages until shutdown or fatal error. async fn run_until_error(&self) -> Result<(), String> @@ -593,6 +614,8 @@ where trace!(%addr, pcap_text, "Sending response"); } + metrics.inc_num_pending_writes(); + // Actually write the DNS response message bytes to the UDP // socket. let _ = Self::send_to( @@ -804,6 +827,10 @@ where Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static + Clone, Svc::Target: Send + Composer + Default, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, { fn drop(&mut self) { // Shutdown the DgramServer. Don't handle the failure case here as diff --git a/src/net/server/middleware/processors/mandatory_svc.rs b/src/net/server/middleware/processors/mandatory_svc.rs index fd25bc616..440322a70 100644 --- a/src/net/server/middleware/processors/mandatory_svc.rs +++ b/src/net/server/middleware/processors/mandatory_svc.rs @@ -1,7 +1,13 @@ //! Core DNS RFC standards based message processing for MUST requirements. -use core::ops::ControlFlow; +use core::future::{ready, Ready}; +use core::marker::PhantomData; +use core::ops::{ControlFlow, DerefMut}; + +use std::boxed::Box; +use std::fmt::Display; -use futures::StreamExt; +use futures::stream::{once, Once}; +use futures::{Stream, StreamExt}; use octseq::Octets; use tracing::{debug, error, trace, warn}; @@ -10,12 +16,8 @@ use crate::base::message_builder::{AdditionalBuilder, PushError}; use crate::base::wire::{Composer, ParseError}; use crate::base::StreamTarget; use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::service::{ - CallResult, Service, ServiceError, Transaction, -}; +use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::{mk_builder_for_target, start_reply}; -use core::marker::PhantomData; -use std::fmt::Display; /// The minimum legal UDP response size in bytes. /// @@ -39,24 +41,18 @@ pub const MINIMUM_RESPONSE_BYTE_LEN: u16 = 512; /// [1035]: https://datatracker.ietf.org/doc/html/rfc1035 /// [2181]: https://datatracker.ietf.org/doc/html/rfc2181 #[derive(Debug)] -pub struct MandatoryMiddlewareSvc -where - RequestOctets: AsRef<[u8]>, - S: Service, -{ +pub struct MandatoryMiddlewareSvc { /// In strict mode the processor does more checks on requests and /// responses. strict: bool, inner: S, - _phantom: PhantomData, + _phantom: PhantomData<(RequestOctets, Target)>, } -impl MandatoryMiddlewareSvc -where - RequestOctets: Octets, - S: Service, +impl + MandatoryMiddlewareSvc { /// Creates a new processor instance. /// @@ -84,25 +80,24 @@ where /// Create a DNS error response to the given request with the given RCODE. fn error_response( - &self, request: &Request, rcode: Rcode, - ) -> AdditionalBuilder> + strict: bool, + ) -> AdditionalBuilder> where - S::Target: Composer + Default, + RequestOctets: Octets, + Target: Composer + Default, { let mut response = start_reply(request); response.header_mut().set_rcode(rcode); let mut additional = response.additional(); - self.postprocess(request, &mut additional); + Self::postprocess(request, &mut additional, strict); additional } } -impl MandatoryMiddlewareSvc -where - RequestOctets: Octets, - S: Service, +impl + MandatoryMiddlewareSvc { /// Truncate the given response message if it is too large. /// @@ -116,10 +111,11 @@ where /// specified byte length. fn truncate( request: &Request, - response: &mut AdditionalBuilder>, + response: &mut AdditionalBuilder>, ) -> Result<(), TruncateError> where - S::Target: Composer + Default, + RequestOctets: Octets, + Target: Composer + Default, { if let TransportSpecificContext::Udp(ctx) = request.transport_ctx() { // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 @@ -216,9 +212,10 @@ where fn preprocess( &self, request: &Request, - ) -> ControlFlow>> + ) -> ControlFlow>> where - S::Target: Composer + Default, + RequestOctets: Octets, + Target: Composer + Default, { // https://www.rfc-editor.org/rfc/rfc3425.html // 3 - Effect on RFC 1035 @@ -231,24 +228,28 @@ where debug!( "RFC 3425 3 violation: request opcode IQUERY is obsolete." ); - return ControlFlow::Break( - self.error_response(request, Rcode::NOTIMP), - ); + return ControlFlow::Break(Self::error_response( + request, + Rcode::NOTIMP, + self.strict, + )); } ControlFlow::Continue(()) } fn postprocess( - &self, request: &Request, - response: &mut AdditionalBuilder>, + response: &mut AdditionalBuilder>, + strict: bool, ) where - S::Target: Composer + Default, + RequestOctets: Octets, + Target: Composer + Default, { if let Err(err) = Self::truncate(request, response) { error!("Error while truncating response: {err}"); - *response = self.error_response(request, Rcode::SERVFAIL); + *response = + Self::error_response(request, Rcode::SERVFAIL, strict); return; } @@ -282,7 +283,7 @@ where // opcode 1, which was obsoleted by RFC 4325) contain the question // from the request. So we would expect the number of questions in the // response to match the number of questions in the request. - if self.strict + if strict && !request.message().header_counts().qdcount() == response.counts().qdcount() { @@ -293,36 +294,81 @@ where //--- Service -impl Service - for MandatoryMiddlewareSvc +pub enum MiddlewareStream where - RequestOctets: Octets, - S: Service, - S::Target: Composer + Default, + S: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, { - type Target = S::Target; + Continue(S), + Map( + Box< + dyn futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + >, + ), + BreakOne(Once, ServiceError>>>), +} + +impl Stream for MiddlewareStream +where + S: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, +{ + type Item = Result, ServiceError>; + + fn poll_next( + mut self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + match self.deref_mut() { + MiddlewareStream::Continue(s) => s.poll_next_unpin(cx), + MiddlewareStream::Map(s) => s.poll_next_unpin(cx), + MiddlewareStream::BreakOne(s) => s.poll_next_unpin(cx), + } + } +} - type Stream = S::Stream; +impl Service + for MandatoryMiddlewareSvc +where + RequestOctets: Octets + 'static, + S: Service, + S::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin + + 'static, + Target: Composer + Default + 'static, +{ + type Target = Target; + type Stream = MiddlewareStream; - fn call( - &self, - request: Request, - ) -> Self::Stream { + fn call(&self, request: Request) -> Self::Stream { match self.preprocess(&request) { ControlFlow::Continue(()) => { - self.inner.call(request).map(|res| { - res.and_then(|cr| { - cr.get_response_mut().and_then(|response| { - self.postprocess(&request, response); - Some(response) - }); - Ok(cr) - }) - }) + let strict = self.strict; + let cloned_request = request.clone(); + let map = self.inner.call(request).map(move |mut res| { + if let Ok(cr) = &mut res { + if let Some(response) = cr.get_response_mut() { + Self::postprocess( + &cloned_request, + response, + strict, + ); + } + } + res + }); + MiddlewareStream::Map(Box::new(map)) } ControlFlow::Break(mut response) => { - self.postprocess(&request, &mut response); - Ok(Transaction::immediate(Ok(CallResult::new(response)))) + Self::postprocess(&request, &mut response, self.strict); + MiddlewareStream::BreakOne(once(ready(Ok(CallResult::new( + response, + ))))) } } } @@ -390,30 +436,27 @@ impl From for TruncateError { #[cfg(test)] mod tests { - use core::ops::ControlFlow; + use core::pin::Pin; + use std::boxed::Box; use std::vec::Vec; use bytes::Bytes; + use futures::stream::Once; + use futures::StreamExt; + use octseq::OctetsBuilder; use tokio::time::Instant; + use super::MandatoryMiddlewareSvc; + + use crate::base::iana::{OptionCode, Rcode}; use crate::base::{Dname, MessageBuilder, Rtype}; use crate::net::server::message::{ Request, TransportSpecificContext, UdpTransportContext, }; - - use super::MandatoryMiddlewareSvc; - use crate::base::iana::{OptionCode, Rcode}; - use crate::net::server::middleware::processor::MiddlewareProcessor; use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; - use crate::net::server::service::{ - CallResult, ServiceError, Transaction, - }; + use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::{mk_builder_for_target, service_fn}; - use core::pin::Pin; - use octseq::OctetsBuilder; - use std::boxed::Box; - use std::future::Future; //------------ Constants ------------------------------------------------- @@ -424,25 +467,25 @@ mod tests { //------------ Tests ----------------------------------------------------- - #[test] - fn clamp_max_response_size_correctly() { - assert!(process(None) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); - assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); - assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); - assert!(process(Some(HUGE)) <= Some(HUGE as usize)); - assert!(process(Some(HUGE)) <= Some(HUGE as usize)); - assert!(process(Some(HUGE)) <= Some(HUGE as usize)); + #[tokio::test] + async fn clamp_max_response_size_correctly() { + assert!(process(None).await <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(TOO_SMALL)).await <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(TOO_SMALL)).await <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(TOO_SMALL)).await <= Some(MIN_ALLOWED as usize)); + assert!(process(Some(JUST_RIGHT)).await <= Some(JUST_RIGHT as usize)); + assert!(process(Some(JUST_RIGHT)).await <= Some(JUST_RIGHT as usize)); + assert!(process(Some(JUST_RIGHT)).await <= Some(JUST_RIGHT as usize)); + assert!(process(Some(HUGE)).await <= Some(HUGE as usize)); + assert!(process(Some(HUGE)).await <= Some(HUGE as usize)); + assert!(process(Some(HUGE)).await <= Some(HUGE as usize)); } //------------ Helper functions ------------------------------------------ // Returns Some(n) if truncation occurred where n is the size after // truncation. - fn process(max_response_size_hint: Option) -> Option { + async fn process(max_response_size_hint: Option) -> Option { // Build a dummy DNS query. let query = MessageBuilder::new_vec(); let mut query = query.question(); @@ -480,40 +523,39 @@ mod tests { fn my_service( req: Request>, _meta: (), - ) -> Result< - Transaction< - Vec, - Pin< - Box< - dyn Future< + ) -> Once< + Pin< + Box< + dyn std::future::Future< Output = Result< CallResult>, ServiceError, >, - >, - >, + > + Send, >, >, - ServiceError, > { // For each request create a single response: - Ok(Transaction::single(Box::pin(async move { + let msg = req.message().clone(); + futures::stream::once(Box::pin(async move { let builder = mk_builder_for_target(); - let answer = - builder.start_answer(req.message(), Rcode::NXDOMAIN)?; + let answer = builder.start_answer(&msg, Rcode::NXDOMAIN)?; Ok(CallResult::new(answer.additional())) - }))) + })) } - // And pass the query through the middleware processor - let processor = - MandatoryMiddlewareSvc::new(service_fn(my_service, ())); - let processor: &dyn MiddlewareProcessor, Vec> = - &processor; - let mut response = MessageBuilder::new_stream_vec().additional(); - if let ControlFlow::Continue(()) = processor.preprocess(&request) { - processor.postprocess(&request, &mut response); - } + // Either call the service directly. + let my_svc = service_fn(my_service, ()); + let mut stream = my_svc.call(request.clone()); + let _call_result: CallResult> = + stream.next().await.unwrap().unwrap(); + + // Or pass the query through the middleware processor + let processor_svc = MandatoryMiddlewareSvc::new(my_svc); + let mut stream = processor_svc.call(request); + let call_result: CallResult> = + stream.next().await.unwrap().unwrap(); + let (response, _feedback) = call_result.into_inner(); // Get the response length let new_size = response.as_slice().len(); diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 6381d895f..90246cc79 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -217,7 +217,7 @@ pub mod dgram; pub mod error; pub mod message; pub mod metrics; -// pub mod middleware; +pub mod middleware; pub mod service; pub mod sock; // pub mod stream; diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 81a5f7f6f..60ae8e226 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -203,18 +203,15 @@ pub trait Service = Vec> { /// The type of buffer in which response messages are stored. type Target; + type Stream; + /// The type of future returned by [`Service::call()`] via /// [`Transaction::single()`]. // type Item: ; /// Generate a response to a fully pre-processed request. #[allow(clippy::type_complexity)] - fn call( - &self, - request: Request, - ) -> impl futures::stream::Stream< - Item = Result, ServiceError>, - > + Send + Unpin; + fn call(&self, request: Request) -> Self::Stream; } /// Helper trait impl to treat an [`Arc`] as a [`Service`]. @@ -222,13 +219,9 @@ impl, T: Service> Service for Arc { type Target = T::Target; + type Stream = T::Stream; - fn call( - &self, - request: Request, - ) -> impl futures::stream::Stream< - Item = Result, ServiceError>, - > { + fn call(&self, request: Request) -> Self::Stream { Arc::deref(self).call(request) } } @@ -237,21 +230,16 @@ impl, T: Service> impl Service for F where RequestOctets: AsRef<[u8]>, - F: Fn( - Request, - ) -> Stream, + F: Fn(Request) -> Stream, Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send + Unpin + Item = Result, ServiceError>, + > + Send + + 'static, { type Target = Target; + type Stream = Stream; - fn call( - &self, - request: Request, - ) -> impl futures::stream::Stream< - Item = Result, ServiceError>, - > + Send + Unpin { + fn call(&self, request: Request) -> Self::Stream { (*self)(request) } } diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index 2e2021cb7..6687b13f0 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -296,7 +296,7 @@ impl MyService { impl Service> for MyService { type Target = Vec; - // type Stream = MySingle; + type Stream = MySingle; fn call(&self, _msg: Request>) -> MySingle { MySingle diff --git a/src/net/server/util.rs b/src/net/server/util.rs index d90e8250d..5d05ab76c 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -14,7 +14,7 @@ use crate::base::{MessageBuilder, ParsedDname, Rtype, StreamTarget}; use crate::rdata::AllRecordData; use super::message::Request; -use super::service::{CallResult, Service, ServiceError, Transaction}; +use super::service::{CallResult, Service, ServiceError}; use crate::base::iana::Rcode; //----------- mk_builder_for_target() ---------------------------------------- @@ -101,12 +101,12 @@ where pub fn service_fn( request_handler: T, metadata: Metadata, -) -> impl Service + Clone +) -> impl Service + Clone where RequestOctets: AsRef<[u8]>, Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send + Unpin, + Item = Result, ServiceError>, + > + Send + 'static, Metadata: Clone, T: Fn(Request, Metadata) -> Stream + Clone, { From 724d670cb09774efebb68db8b2db87424f1b786c Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Mon, 15 Apr 2024 21:58:21 +0200 Subject: [PATCH 03/28] WIP: Working Box based Service and mandatory middleware, with stream server re-enabled but still no other middleware. --- examples/serve-zone.rs | 32 +++---- src/net/server/connection.rs | 174 ++++++++++++++--------------------- src/net/server/mod.rs | 6 +- src/net/server/stream.rs | 89 ++++++++++-------- src/net/server/util.rs | 3 +- 5 files changed, 138 insertions(+), 166 deletions(-) diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index 8d0e6f503..7f6e7d66f 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -22,7 +22,7 @@ use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; use domain::net::server::service::{CallResult, ServiceError}; -// use domain::net::server::stream::StreamServer; +use domain::net::server::stream::StreamServer; use domain::net::server::util::{mk_builder_for_target, service_fn}; use domain::zonefile::inplace; use domain::zonetree::{Answer, Rrset}; @@ -75,14 +75,14 @@ async fn main() { DgramServer::new(sock.clone(), VecBufSource, svc.clone()); let metrics = udp_srv.metrics(); udp_metrics.push(metrics); - tokio::spawn(async move { udp_srv.run().await }); + tokio::spawn(udp_srv.run()); } - // let sock = TcpListener::bind(addr).await.unwrap(); - // let tcp_srv = StreamServer::new(sock, VecBufSource, svc); - // let tcp_metrics = tcp_srv.metrics(); + let sock = TcpListener::bind(addr).await.unwrap(); + let tcp_srv = StreamServer::new(sock, VecBufSource, svc); + let tcp_metrics = tcp_srv.metrics(); - // tokio::spawn(async move { tcp_srv.run().await }); + tokio::spawn(tcp_srv.run()); tokio::spawn(async move { loop { @@ -97,14 +97,14 @@ async fn main() { metrics.num_sent_responses(), ); } - // eprintln!( - // "Server status: TCP: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", - // tcp_metrics.num_connections(), - // tcp_metrics.num_inflight_requests(), - // tcp_metrics.num_pending_writes(), - // tcp_metrics.num_received_requests(), - // tcp_metrics.num_sent_responses(), - // ); + eprintln!( + "Server status: TCP: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", + tcp_metrics.num_connections(), + tcp_metrics.num_inflight_requests(), + tcp_metrics.num_pending_writes(), + tcp_metrics.num_received_requests(), + tcp_metrics.num_sent_responses(), + ); } }); @@ -131,7 +131,8 @@ enum SingleOrStream { Box< dyn futures::stream::Stream< Item = Result>, ServiceError>, - > + Unpin + Send, + > + Unpin + + Send, >, ), } @@ -201,7 +202,6 @@ fn handle_axfr_request( >, >, > { - // let mut stream = TransactionStream::default(); let mut stream = FuturesOrdered::< Pin< Box< diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index 1c3177f8e..1e3f83058 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -6,6 +6,7 @@ use std::io; use std::net::SocketAddr; use std::sync::Arc; +use futures::StreamExt; use octseq::Octets; use tokio::io::{ AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, @@ -21,10 +22,8 @@ use tracing::{debug, enabled, error, trace, warn}; use crate::base::wire::Composer; use crate::base::{Message, StreamTarget}; use crate::net::server::buf::BufSource; -// use crate::net::server::message::CommonMessageFlow; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -// use crate::net::server::middleware::chain::MiddlewareChain; use crate::net::server::service::{ CallResult, Service, ServiceError, ServiceFeedback, }; @@ -32,7 +31,6 @@ use crate::net::server::util::to_pcap_text; use crate::utils::config::DefMinMax; use super::message::{NonUdpTransportContext, TransportSpecificContext}; -// use super::middleware::builder::MiddlewareBuilder; use super::stream::Config as ServerConfig; use super::ServerCommand; use std::fmt::Display; @@ -90,7 +88,7 @@ const MAX_QUEUED_RESPONSES: DefMinMax = DefMinMax::new(10, 0, 1024); //----------- Config --------------------------------------------------------- /// Configuration for a stream server connection. -pub struct Config { +pub struct Config { /// Limit on the amount of time to allow between client requests. /// /// This setting can be overridden on a per connection basis by a @@ -117,17 +115,12 @@ pub struct Config { /// Limit on the number of DNS responses queued for wriing to the client. max_queued_responses: usize, - // /// The middleware chain used to pre-process requests and post-process // /// responses. // middleware_chain: MiddlewareChain, } -impl Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl Config { /// Creates a new, default config. #[allow(dead_code)] pub fn new() -> Self { @@ -231,30 +224,24 @@ where //--- Default -impl Default for Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl Default for Config { fn default() -> Self { Self { idle_timeout: IDLE_TIMEOUT.default(), response_write_timeout: RESPONSE_WRITE_TIMEOUT.default(), max_queued_responses: MAX_QUEUED_RESPONSES.default(), - // middleware_chain: MiddlewareBuilder::default().build(), } } } //--- Clone -impl Clone for Config { +impl Clone for Config { fn clone(&self) -> Self { Self { idle_timeout: self.idle_timeout, response_write_timeout: self.response_write_timeout, max_queued_responses: self.max_queued_responses, - // middleware_chain: self.middleware_chain.clone(), } } } @@ -279,7 +266,7 @@ where /// /// Note: Some reconfiguration is possible at runtime via /// [`ServerCommand::Reconfigure`] and [`ServiceFeedback::Reconfigure`]. - config: Config, + config: Config, /// The address of the connected client. addr: SocketAddr, @@ -348,7 +335,7 @@ where metrics: Arc, stream: Stream, addr: SocketAddr, - config: Config, + config: Config, ) -> Self { let (stream_rx, stream_tx) = tokio::io::split(stream); let (result_q_tx, result_q_rx) = @@ -388,6 +375,10 @@ where Buf: BufSource + Send + Sync + Clone + 'static, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Send + Composer + Default, { /// Start reading requests and writing responses to the stream. @@ -403,9 +394,7 @@ where /// for writing. pub async fn run( mut self, - command_rx: watch::Receiver< - ServerCommand>, - >, + command_rx: watch::Receiver>, ) where Svc::Stream: Send, { @@ -426,15 +415,16 @@ where Buf: BufSource + Send + Sync + Clone + 'static, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + 'static, - Svc::Stream: Send, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Send + Composer + Default, { /// Connection handler main loop. async fn run_until_error( mut self, - mut command_rx: watch::Receiver< - ServerCommand>, - >, + mut command_rx: watch::Receiver>, ) { // SAFETY: This unwrap is safe because we always put a Some value into // self.stream_rx in [`Self::with_config`] above (and thus also in @@ -514,9 +504,7 @@ where fn process_server_command( &mut self, res: Result<(), watch::error::RecvError>, - command_rx: &mut watch::Receiver< - ServerCommand>, - >, + command_rx: &mut watch::Receiver>, ) -> Result<(), ConnectionEvent> { // If the parent server no longer exists but was not cleanly shutdown // then the command channel will be closed and attempting to check for @@ -552,7 +540,6 @@ where idle_timeout, response_write_timeout, max_queued_responses: _, - // middleware_chain: _, }, .. // Ignore the Server specific configuration settings }) => { @@ -715,13 +702,17 @@ where res: Result, ) -> Result<(), ConnectionEvent> where - Svc::Stream: Send, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin + + Send, { - res.and_then(|msg| { + if let Ok(buf) = res { let received_at = Instant::now(); if enabled!(Level::TRACE) { - let pcap_text = to_pcap_text(&msg, msg.as_ref().len()); + let pcap_text = to_pcap_text(&buf, buf.as_ref().len()); trace!(addr = %self.addr, pcap_text, "Received message"); } @@ -730,20 +721,49 @@ where // Message received, reset the DNS idle timer self.idle_timer.full_msg_received(); - // Process the received message - // self.process_request( - // msg, - // received_at, - // self.addr, - // self.config.middleware_chain.clone(), - // &self.service, - // self.metrics.clone(), - // self.result_q_tx.clone(), - // ) - // .map_err(ConnectionEvent::ServiceError) - - todo!() - }) + match Message::from_octets(buf) { + Err(err) => { + tracing::warn!( + "Failed while parsing request message: {err}" + ); + return Err(ConnectionEvent::ServiceError( + ServiceError::FormatError, + )); + } + + Ok(msg) => { + let ctx = NonUdpTransportContext::new(Some( + self.config.idle_timeout, + )); + let ctx = TransportSpecificContext::NonUdp(ctx); + let request = + Request::new(self.addr, received_at, msg, ctx); + let mut stream = self.service.call(request); + while let Some(Ok(call_result)) = stream.next().await { + match self.result_q_tx.try_send(call_result) { + Ok(()) => { + self.metrics.set_num_pending_writes( + self.result_q_tx.max_capacity() + - self.result_q_tx.capacity(), + ); + } + + Err(TrySendError::Closed(_msg)) => { + // TODO: How should we properly communicate this to the operator? + error!("Unable to queue message for sending: server is shutting down."); + } + + Err(TrySendError::Full(_msg)) => { + // TODO: How should we properly communicate this to the operator? + error!("Unable to queue message for sending: queue is full."); + } + } + } + } + } + } + + Ok(()) } } @@ -762,64 +782,6 @@ where } } -// //--- CommonMessageFlow - -// impl CommonMessageFlow -// for Connection -// where -// Buf: BufSource, -// Buf::Output: Octets + Send + Sync + 'static, -// Svc: Service + Send + Sync + 'static, -// Svc::Target: Send, -// { -// type Meta = Sender>; - -// /// Add information to the request that relates to the type of server we -// /// are and our state where relevant. -// fn add_context_to_request( -// &self, -// request: Message, -// received_at: Instant, -// addr: SocketAddr, -// ) -> Request { -// let ctx = NonUdpTransportContext::new(Some(self.config.idle_timeout)); -// let ctx = TransportSpecificContext::NonUdp(ctx); -// Request::new(addr, received_at, request, ctx) -// } - -// /// Process the result from the middleware -> service -> middleware call -// /// tree. -// fn process_call_result( -// _request: &Request, -// call_result: CallResult, -// tx: Self::Meta, -// metrics: Arc, -// ) { -// // We can't send in a spawned async task as then we would just -// // accumlate tasks even if the target queue is full. We can't call -// // `tx.blocking_send()` as that would block the Tokio runtime. So -// // instead we try and send and if that fails because the queue is full -// // then we abort. -// match tx.try_send(call_result) { -// Ok(()) => { -// metrics.set_num_pending_writes( -// tx.max_capacity() - tx.capacity(), -// ); -// } - -// Err(TrySendError::Closed(_msg)) => { -// // TODO: How should we properly communicate this to the operator? -// error!("Unable to queue message for sending: server is shutting down."); -// } - -// Err(TrySendError::Full(_msg)) => { -// // TODO: How should we properly communicate this to the operator? -// error!("Unable to queue message for sending: queue is full."); -// } -// } -// } -// } - //----------- DnsMessageReceiver --------------------------------------------- /// The [`DnsMessageReceiver`] state machine. diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 90246cc79..13c8c34a7 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -209,8 +209,8 @@ #![cfg(feature = "unstable-server-transport")] #![cfg_attr(docsrs, doc(cfg(feature = "unstable-server-transport")))] -// mod connection; -// pub use connection::Config as ConnectionConfig; +mod connection; +pub use connection::Config as ConnectionConfig; pub mod buf; pub mod dgram; @@ -220,7 +220,7 @@ pub mod metrics; pub mod middleware; pub mod service; pub mod sock; -// pub mod stream; +pub mod stream; pub mod util; #[cfg(test)] diff --git a/src/net/server/stream.rs b/src/net/server/stream.rs index d642b84e5..76febc667 100644 --- a/src/net/server/stream.rs +++ b/src/net/server/stream.rs @@ -38,6 +38,7 @@ use crate::utils::config::DefMinMax; use super::buf::VecBufSource; use super::connection::{self, Connection}; +use super::service::{CallResult, ServiceError}; use super::ServerCommand; use crate::base::wire::Composer; use tokio::io::{AsyncRead, AsyncWrite}; @@ -73,7 +74,7 @@ const MAX_CONCURRENT_TCP_CONNECTIONS: DefMinMax = //----------- Config --------------------------------------------------------- /// Configuration for a stream server. -pub struct Config { +pub struct Config { /// Limit on the number of concurrent TCP connections that can be handled /// by the server. max_concurrent_connections: usize, @@ -82,14 +83,10 @@ pub struct Config { accept_connections_at_max: bool, /// Connection specific configuration. - pub(super) connection_config: connection::Config, + pub(super) connection_config: connection::Config, } -impl Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl Config { /// Creates a new, default config. pub fn new() -> Self { Default::default() @@ -145,26 +142,20 @@ where /// See [`connection::Config`] for more information. pub fn set_connection_config( &mut self, - connection_config: connection::Config, + connection_config: connection::Config, ) { self.connection_config = connection_config; } /// Gets the connection specific configuration. - pub fn connection_config( - &self, - ) -> &connection::Config { + pub fn connection_config(&self) -> &connection::Config { &self.connection_config } } //--- Default -impl Default for Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl Default for Config { fn default() -> Self { Self { accept_connections_at_max: true, @@ -177,11 +168,7 @@ where //--- Clone -impl Clone for Config -where - RequestOctets: Octets, - Target: Composer + Default, -{ +impl Clone for Config { fn clone(&self) -> Self { Self { accept_connections_at_max: self.accept_connections_at_max, @@ -194,16 +181,13 @@ where //------------ StreamServer -------------------------------------------------- /// A [`ServerCommand`] capable of propagating a StreamServer [`Config`] value. -type ServerCommandType = - ServerCommand>; +type ServerCommandType = ServerCommand; /// A thread safe sender of [`ServerCommand`]s. -type CommandSender = - Arc>>>; +type CommandSender = Arc>>; /// A thread safe receiver of [`ServerCommand`]s. -type CommandReceiver = - watch::Receiver>; +type CommandReceiver = watch::Receiver; /// A server for connecting clients via stream based network transport to a /// [`Service`]. @@ -296,21 +280,25 @@ where Buf: BufSource + Send + Sync + Clone, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + Clone, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Composer + Default + 'static, { /// The configuration of the server. - config: Arc>>, + config: Arc>, /// A receiver for receiving [`ServerCommand`]s. /// /// Used by both the server and spawned connections to react to sent /// commands. - command_rx: CommandReceiver, + command_rx: CommandReceiver, /// A sender for sending [`ServerCommand`]s. /// /// Used to signal the server to stop, reconfigure, etc. - command_tx: CommandSender, + command_tx: CommandSender, /// A listener for listening for and accepting incoming stream /// connections. @@ -341,6 +329,10 @@ where Buf: BufSource + Send + Sync + Clone, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + Clone, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Composer + Default, { /// Creates a new [`StreamServer`] instance. @@ -361,7 +353,7 @@ where listener: Listener, buf: Buf, service: Svc, - config: Config, + config: Config, ) -> Self { let (command_tx, command_rx) = watch::channel(ServerCommand::Init); let command_tx = Arc::new(Mutex::new(command_tx)); @@ -421,6 +413,10 @@ where Buf: BufSource + Send + Sync + Clone, Buf::Output: Octets + Debug + Send + Sync, Svc: Service + Send + Sync + Clone, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Composer + Default, { /// Get a reference to the source for this server. @@ -444,6 +440,10 @@ where Buf: BufSource + Send + Sync + Clone, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + Clone, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Composer + Default + 'static, { /// Start the server. @@ -473,10 +473,7 @@ where /// /// This command will be received both by the server and by any existing /// connections. - pub fn reconfigure( - &self, - config: Config, - ) -> Result<(), Error> { + pub fn reconfigure(&self, config: Config) -> Result<(), Error> { self.command_tx .lock() .map_err(|_| Error::CommandCouldNotBeSent)? @@ -541,6 +538,10 @@ where Buf: BufSource + Send + Sync + Clone, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + Clone, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Composer + Default, { /// Accept stream connections until shutdown or fatal error. @@ -552,7 +553,10 @@ where Listener::Future: Send + 'static, Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, - Svc::Stream: Send, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Send + Sync + 'static, { let mut command_rx = self.command_rx.clone(); @@ -615,9 +619,7 @@ where fn process_server_command( &self, res: Result<(), watch::error::RecvError>, - command_rx: &mut watch::Receiver< - ServerCommand>, - >, + command_rx: &mut watch::Receiver>, ) -> Result<(), String> { // If the parent server no longer exists but was not cleanly shutdown // then the command channel will be closed and attempting to check for @@ -675,7 +677,10 @@ where Listener::Future: Send + 'static, Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, - Svc::Stream: Send, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Send + Sync + 'static, { // Work around the compiler wanting to move self to the async block by @@ -741,6 +746,10 @@ where Buf: BufSource + Send + Sync + Clone, Buf::Output: Octets + Send + Sync, Svc: Service + Send + Sync + Clone, + Svc::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Send + + Unpin, Svc::Target: Composer + Default + 'static, { fn drop(&mut self) { diff --git a/src/net/server/util.rs b/src/net/server/util.rs index 5d05ab76c..aa55560fe 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -106,7 +106,8 @@ where RequestOctets: AsRef<[u8]>, Stream: futures::stream::Stream< Item = Result, ServiceError>, - > + Send + 'static, + > + Send + + 'static, Metadata: Clone, T: Fn(Request, Metadata) -> Stream + Clone, { From 4cc2bd5da2ae2af66e9842d405922c382d57842b Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Mon, 15 Apr 2024 22:01:44 +0200 Subject: [PATCH 04/28] Compilation fix. --- examples/serve-zone.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index 7f6e7d66f..8a110eb49 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -75,14 +75,14 @@ async fn main() { DgramServer::new(sock.clone(), VecBufSource, svc.clone()); let metrics = udp_srv.metrics(); udp_metrics.push(metrics); - tokio::spawn(udp_srv.run()); + tokio::spawn(async move { udp_srv.run().await }); } let sock = TcpListener::bind(addr).await.unwrap(); let tcp_srv = StreamServer::new(sock, VecBufSource, svc); let tcp_metrics = tcp_srv.metrics(); - tokio::spawn(tcp_srv.run()); + tokio::spawn(async move { tcp_srv.run().await }); tokio::spawn(async move { loop { From 32b46231451ed3801d316c82a76cf6ae89f31221 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Mon, 15 Apr 2024 22:01:50 +0200 Subject: [PATCH 05/28] Remove commented out code. --- src/net/server/connection.rs | 1 - src/net/server/message.rs | 322 ----------------------------- src/net/server/middleware/chain.rs | 2 +- 3 files changed, 1 insertion(+), 324 deletions(-) diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index 1e3f83058..4f10eeaa3 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -12,7 +12,6 @@ use tokio::io::{ AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, WriteHalf, }; use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::mpsc::Sender; use tokio::sync::{mpsc, watch}; use tokio::time::Instant; use tokio::time::{sleep_until, timeout}; diff --git a/src/net/server/message.rs b/src/net/server/message.rs index 2c2cec0ff..c77664767 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -1,23 +1,12 @@ //! Support for working with DNS messages in servers. -use core::ops::ControlFlow; use core::time::Duration; -use std::net::SocketAddr; use std::sync::{Arc, Mutex}; -use octseq::Octets; use tokio::time::Instant; -use tracing::Level; -use tracing::{enabled, error, info_span, warn}; use crate::base::Message; -use crate::net::server::buf::BufSource; -use crate::net::server::metrics::ServerMetrics; -// use crate::net::server::middleware::chain::MiddlewareChain; -use super::service::{CallResult, Service, ServiceError, Transaction}; -use super::util::start_reply; -use crate::base::wire::Composer; //------------ UdpTransportContext ------------------------------------------- @@ -213,314 +202,3 @@ impl> Clone for Request { } } } - -// //----------- CommonMessageFlow ---------------------------------------------- - -// /// Perform processing common to all messages being handled by a DNS server. -// /// -// /// All messages received by a DNS server need to pass through the following -// /// processing stages: -// /// -// /// - Pre-processing. -// /// - Service processing. -// /// - Post-processing. -// /// -// /// The strategy is common but some server specific aspects are delegated to -// /// the server that implements this trait: -// /// -// /// - Adding context to a request. -// /// - Finalizing the handling of a response. -// /// -// /// Servers implement this trait to benefit from the common processing -// /// required while still handling aspects specific to the server themselves. -// /// -// /// Processing starts at [`process_request`]. -// /// -// ///
-// /// -// /// This trait exists as a convenient mechanism for sharing common code -// /// between server implementations. The default function implementations -// /// provided by this trait are not intended to be overridden by consumers of -// /// this library. -// /// -// ///
-// /// -// /// [`process_request`]: Self::process_request() -// pub trait CommonMessageFlow -// where -// Buf: BufSource, -// Buf::Output: Octets + Send + Sync, -// Svc: Service + Send + Sync, -// { -// /// Server-specific data that it chooses to pass along with the request in -// /// order that it may receive it when `process_call_result()` is -// /// invoked on the implementing server. -// type Meta: Clone + Send + Sync + 'static; - -// /// Process a DNS request message. -// /// -// /// This function consumes the given message buffer and processes the -// /// contained message, if any, to completion, possibly resulting in a -// /// response being passed to [`Self::process_call_result`]. -// /// -// /// The request message is a given as a seqeuence of bytes in `buf` -// /// originating from client address `addr`. -// /// -// /// The [`MiddlewareChain`] and [`Service`] to be used to process the -// /// message are supplied in the `middleware_chain` and `svc` arguments -// /// respectively. -// /// -// /// Any server specific state to be used and/or updated as part of the -// /// processing should be supplied via the `state` argument whose type is -// /// defined by the implementing type. -// /// -// /// On error the result will be a [`ServiceError`]. -// #[allow(clippy::too_many_arguments)] -// fn process_request( -// &self, -// buf: Buf::Output, -// received_at: Instant, -// addr: SocketAddr, -// middleware_chain: MiddlewareChain, -// svc: &Svc, -// metrics: Arc, -// meta: Self::Meta, -// ) -> Result<(), ServiceError> -// where -// Svc: 'static, -// Svc::Target: Send + Composer + Default, -// Svc::Stream: Send, -// Buf::Output: 'static, -// { -// boomerang( -// self, -// buf, -// received_at, -// addr, -// middleware_chain, -// metrics, -// svc, -// meta, -// ) -// } - -// /// Add context to a request. -// /// -// /// The server supplies this function to annotate the received message -// /// with additional information about its origins. -// fn add_context_to_request( -// &self, -// request: Message, -// received_at: Instant, -// addr: SocketAddr, -// ) -> Request; - -// /// Finalize a response. -// /// -// /// The server supplies this function to handle the response as -// /// appropriate for the server, e.g. to write the response back to the -// /// originating client. -// /// -// /// The response is the form of a [`CallResult`]. -// fn process_call_result( -// request: &Request, -// call_result: CallResult, -// state: Self::Meta, -// metrics: Arc, -// ); -// } - -// /// Propogate a message through the [`MiddlewareChain`] to the [`Service`] and -// /// flow the response in reverse back down the same path, a bit like throwing -// /// a boomerang. -// #[allow(clippy::too_many_arguments)] -// fn boomerang( -// server: &Server, -// buf: ::Output, -// received_at: Instant, -// addr: SocketAddr, -// middleware_chain: MiddlewareChain< -// ::Output, -// ::Output>>::Target, -// >, -// metrics: Arc, -// svc: &Svc, -// meta: Server::Meta, -// ) -> Result<(), ServiceError> -// where -// Buf: BufSource, -// Buf::Output: Octets + Send + Sync + 'static, -// Svc: Service + Send + Sync + 'static, -// Svc::Stream: Send, -// Svc::Target: Send + Composer + Default, -// Server: CommonMessageFlow + ?Sized, -// { -// let message = Message::from_octets(buf).map_err(|err| { -// warn!("Failed while parsing request message: {err}"); -// ServiceError::InternalError -// })?; - -// let request = server.add_context_to_request(message, received_at, addr); - -// let preprocessing_result = do_middleware_preprocessing::( -// &request, -// &middleware_chain, -// &metrics, -// )?; - -// let (txn, aborted_preprocessor_idx) = -// do_service_call::(preprocessing_result, &request, svc); - -// do_middleware_postprocessing::( -// request, -// meta, -// middleware_chain, -// txn, -// aborted_preprocessor_idx, -// metrics, -// ); - -// Ok(()) -// } - -// /// Pass a pre-processed request to the [`Service`] to handle. -// /// -// /// If [`Service::call`] returns an error this function will produce a DNS -// /// ServFail error response. If the returned error is -// /// [`ServiceError::InternalError`] it will also be logged. -// #[allow(clippy::type_complexity)] -// fn do_service_call( -// preprocessing_result: ControlFlow<(Svc::Stream, usize)>, -// request: &Request<::Output>, -// svc: &Svc, -// ) -> (Svc::Stream, Option) -// where -// Buf: BufSource, -// Buf::Output: Octets, -// Svc: Service, -// Svc::Target: Composer + Default, -// { -// match preprocessing_result { -// ControlFlow::Continue(()) => { -// let res = if enabled!(Level::INFO) { -// let span = info_span!("svc-call", -// msg_id = request.message().header().id(), -// client = %request.client_addr(), -// ); -// let _guard = span.enter(); -// svc.call(request.clone()) -// } else { -// svc.call(request.clone()) -// }; - -// // Handle any error returned by the service. -// // let txn = res.unwrap_or_else(|err| { -// // if matches!(err, ServiceError::InternalError) { -// // error!("Service error while processing request: {err}"); -// // } - -// // let mut response = start_reply(request); -// // response.header_mut().set_rcode(err.rcode()); -// // let call_result = CallResult::new(response.additional()); -// // Transaction::immediate(Ok(call_result)) -// // }); - -// // Pass the transaction out for post-processing. -// (res, None) -// } - -// ControlFlow::Break((txn, aborted_preprocessor_idx)) => { -// (txn, Some(aborted_preprocessor_idx)) -// } -// } -// } - -// /// Pre-process a request. -// /// -// /// Pre-processing involves parsing a [`Message`] from the byte buffer and -// /// pre-processing it via any supplied [`MiddlewareChain`]. -// /// -// /// On success the result is an immutable request message and a -// /// [`ControlFlow`] decision about whether to continue with further processing -// /// or to break early with a possible response. If processing failed the -// /// result will be a [`ServiceError`]. -// /// -// /// On break the result will be one ([`Transaction::single`]) or more -// /// ([`Transaction::stream`]) to post-process. -// #[allow(clippy::type_complexity)] -// fn do_middleware_preprocessing( -// request: &Request, -// middleware_chain: &MiddlewareChain, -// metrics: &Arc, -// ) -> Result, ServiceError> -// where -// Buf: BufSource, -// Buf::Output: Octets + Send + Sync + 'static, -// Svc: Service + Send + Sync, -// Svc::Target: Send + Composer + Default + 'static, -// { -// let span = info_span!("pre-process", -// msg_id = request.message().header().id(), -// client = %request.client_addr(), -// ); -// let _guard = span.enter(); - -// metrics.inc_num_inflight_requests(); - -// let pp_res = middleware_chain.preprocess(request); - -// Ok(pp_res) -// } - -// /// Post-process a response in the context of its originating request. -// /// -// /// Each response is post-processed in its own Tokio task. Note that there is -// /// no guarantee about the order in which responses will be post-processed. If -// /// the order of a seqence of responses is important it should be provided as -// /// a [`Transaction::stream`] rather than [`Transaction::single`]. -// /// -// /// Responses are first post-processed by the [`MiddlewareChain`] provided, if -// /// any, then passed to [`Self::process_call_result`] for final processing. -// #[allow(clippy::type_complexity)] -// fn do_middleware_postprocessing( -// request: Request, -// meta: Server::Meta, -// middleware_chain: MiddlewareChain, -// mut response_txn: Svc::Stream, -// last_processor_id: Option, -// metrics: Arc, -// ) where -// Buf: BufSource, -// Buf::Output: Octets + Send + Sync + 'static, -// Svc: Service + Send + Sync + 'static, -// Svc::Stream: Send, -// Svc::Target: Send + Composer + Default, -// Server: CommonMessageFlow + ?Sized, -// { -// tokio::spawn(async move { -// let span = info_span!("post-process", -// msg_id = request.message().header().id(), -// client = %request.client_addr(), -// ); -// let _guard = span.enter(); - -// while let Some(Ok(mut call_result)) = response_txn.next().await { -// if let Some(response) = call_result.get_response_mut() { -// middleware_chain.postprocess( -// &request, -// response, -// last_processor_id, -// ); -// } - -// Server::process_call_result( -// &request, -// call_result, -// meta.clone(), -// metrics.clone(), -// ); -// } - -// metrics.dec_num_inflight_requests(); -// }); -// } diff --git a/src/net/server/middleware/chain.rs b/src/net/server/middleware/chain.rs index 4546f8d1d..470cd0659 100644 --- a/src/net/server/middleware/chain.rs +++ b/src/net/server/middleware/chain.rs @@ -12,7 +12,7 @@ use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; use crate::base::StreamTarget; use crate::net::server::message::Request; -use crate::net::server::service::{CallResult, ServiceError, Transaction}; +use crate::net::server::service::{CallResult, ServiceError}; use super::processor::MiddlewareProcessor; From c6682b66ca31da087e2e014b0f51c8d248b3aaeb Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 16 Apr 2024 10:15:29 +0200 Subject: [PATCH 06/28] WIP: Working Box based Service and mandatory middleware (a) being actually used and (b) with Box-less post-processing via new `PostprocessingMap` type based on stream::Map, with stream server re-enabled but still no other middleware. --- examples/serve-zone.rs | 5 +- .../middleware/processors/mandatory_svc.rs | 173 ++++++++++++------ 2 files changed, 119 insertions(+), 59 deletions(-) diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index 8a110eb49..948edad4c 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -21,6 +21,7 @@ use domain::base::{Dname, Message, Rtype, ToDname}; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; +use domain::net::server::middleware::processors::mandatory_svc::MandatoryMiddlewareSvc; use domain::net::server::service::{CallResult, ServiceError}; use domain::net::server::stream::StreamServer; use domain::net::server::util::{mk_builder_for_target, service_fn}; @@ -64,7 +65,9 @@ async fn main() { let zones = Arc::new(zones); let addr = "127.0.0.1:8053"; - let svc = Arc::new(service_fn(my_service, zones)); + let business_svc = service_fn(my_service, zones); + let mandatory_svc = MandatoryMiddlewareSvc::new(business_svc); + let svc = Arc::new(mandatory_svc); let sock = UdpSocket::bind(addr).await.unwrap(); let sock = Arc::new(sock); diff --git a/src/net/server/middleware/processors/mandatory_svc.rs b/src/net/server/middleware/processors/mandatory_svc.rs index 440322a70..8b964d213 100644 --- a/src/net/server/middleware/processors/mandatory_svc.rs +++ b/src/net/server/middleware/processors/mandatory_svc.rs @@ -2,11 +2,12 @@ use core::future::{ready, Ready}; use core::marker::PhantomData; use core::ops::{ControlFlow, DerefMut}; +use core::pin::Pin; +use core::task::{Context, Poll}; -use std::boxed::Box; use std::fmt::Display; -use futures::stream::{once, Once}; +use futures::stream::{once, FuturesOrdered, Once}; use futures::{Stream, StreamExt}; use octseq::Octets; use tracing::{debug, error, trace, warn}; @@ -294,28 +295,42 @@ impl //--- Service -pub enum MiddlewareStream +pub enum MiddlewareStream where - S: futures::stream::Stream< + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< Item = Result, ServiceError>, > + Unpin, + Self: Unpin, + Target: Unpin, { - Continue(S), - Map( - Box< - dyn futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - >, + /// The inner service response will be passed through this service without + /// modification. + Passthru(InnerServiceResponseStream), + + /// The inner service response will be post-processed by this service. + Postprocess( + PostprocessingMap, + ), + + /// A single response has been created without invoking the inner service. + HandledOne(Once, ServiceError>>>), + + /// Multiple responses have been created without invoking the inner + /// service. + HandledMany( + FuturesOrdered, ServiceError>>>, ), - BreakOne(Once, ServiceError>>>), } -impl Stream for MiddlewareStream +impl Stream + for MiddlewareStream where + RequestOctets: Octets, S: futures::stream::Stream< Item = Result, ServiceError>, > + Unpin, + Target: Composer + Default + Unpin, { type Item = Result, ServiceError>; @@ -324,9 +339,19 @@ where cx: &mut core::task::Context<'_>, ) -> core::task::Poll> { match self.deref_mut() { - MiddlewareStream::Continue(s) => s.poll_next_unpin(cx), - MiddlewareStream::Map(s) => s.poll_next_unpin(cx), - MiddlewareStream::BreakOne(s) => s.poll_next_unpin(cx), + MiddlewareStream::Passthru(s) => s.poll_next_unpin(cx), + MiddlewareStream::Postprocess(s) => s.poll_next_unpin(cx), + MiddlewareStream::HandledOne(s) => s.poll_next_unpin(cx), + MiddlewareStream::HandledMany(s) => s.poll_next_unpin(cx), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MiddlewareStream::Passthru(s) => s.size_hint(), + MiddlewareStream::Postprocess(s) => s.size_hint(), + MiddlewareStream::HandledOne(s) => s.size_hint(), + MiddlewareStream::HandledMany(s) => s.size_hint(), } } } @@ -340,62 +365,94 @@ where Item = Result, ServiceError>, > + Unpin + 'static, - Target: Composer + Default + 'static, + Target: Composer + Default + 'static + Unpin, { type Target = Target; - type Stream = MiddlewareStream; + type Stream = MiddlewareStream; fn call(&self, request: Request) -> Self::Stream { match self.preprocess(&request) { ControlFlow::Continue(()) => { - let strict = self.strict; - let cloned_request = request.clone(); - let map = self.inner.call(request).map(move |mut res| { - if let Ok(cr) = &mut res { - if let Some(response) = cr.get_response_mut() { - Self::postprocess( - &cloned_request, - response, - strict, - ); - } - } - res - }); - MiddlewareStream::Map(Box::new(map)) + let st = self.inner.call(request.clone()); + let map = PostprocessingMap::new(st, request, self.strict); + MiddlewareStream::Postprocess(map) } ControlFlow::Break(mut response) => { Self::postprocess(&request, &mut response, self.strict); - MiddlewareStream::BreakOne(once(ready(Ok(CallResult::new( - response, - ))))) + MiddlewareStream::HandledOne(once(ready(Ok( + CallResult::new(response), + )))) } } } } -// impl MiddlewareProcessor -// for MandatoryMiddlewareSvc -// where -// RequestOctets: Octets, -// S: Service, -// S::Target: Composer + Default, -// { -// fn preprocess( -// &self, -// request: &Request, -// ) -> ControlFlow>> { -// self.p -// } - -// fn postprocess( -// &self, -// request: &Request, -// response: &mut AdditionalBuilder>, -// ) { -// todo!() -// } -// } +pub struct PostprocessingMap +where + RequestOctets: Octets, + St: futures::stream::Stream< + Item = Result, ServiceError>, + >, +{ + request: Request, + strict: bool, + _phantom: PhantomData, + stream: St, +} + +impl PostprocessingMap +where + RequestOctets: Octets, + St: futures::stream::Stream< + Item = Result, ServiceError>, + >, +{ + pub(crate) fn new( + stream: St, + request: Request, + strict: bool, + ) -> Self { + Self { + stream, + request, + strict, + _phantom: PhantomData, + } + } +} + +impl Stream + for PostprocessingMap +where + RequestOctets: Octets, + St: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + Target: Composer + Default + Unpin, +{ + type Item = Result, ServiceError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = futures::ready!(self.stream.poll_next_unpin(cx)); + let request = self.request.clone(); + let strict = self.strict; + Poll::Ready(res.map(|mut res| { + if let Ok(cr) = &mut res { + if let Some(response) = cr.get_response_mut() { + MandatoryMiddlewareSvc::::postprocess(&request, response, strict); + } + } + res + })) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} //------------ TruncateError ------------------------------------------------- From 8a406d3b8eec19f920a003fa7f5e749fd0dc51a6 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 16 Apr 2024 10:24:31 +0200 Subject: [PATCH 07/28] Factor `MiddlewareStream` out to new util.rs. --- src/net/server/middleware/mod.rs | 1 + .../middleware/processors/mandatory_svc.rs | 112 ++++++------------ src/net/server/middleware/util.rs | 80 +++++++++++++ 3 files changed, 116 insertions(+), 77 deletions(-) create mode 100644 src/net/server/middleware/util.rs diff --git a/src/net/server/middleware/mod.rs b/src/net/server/middleware/mod.rs index 2f5176866..ae61ca78d 100644 --- a/src/net/server/middleware/mod.rs +++ b/src/net/server/middleware/mod.rs @@ -30,3 +30,4 @@ pub mod builder; pub mod chain; pub mod processor; pub mod processors; +pub mod util; diff --git a/src/net/server/middleware/processors/mandatory_svc.rs b/src/net/server/middleware/processors/mandatory_svc.rs index 8b964d213..c3146961c 100644 --- a/src/net/server/middleware/processors/mandatory_svc.rs +++ b/src/net/server/middleware/processors/mandatory_svc.rs @@ -1,13 +1,13 @@ //! Core DNS RFC standards based message processing for MUST requirements. -use core::future::{ready, Ready}; +use core::future::ready; use core::marker::PhantomData; -use core::ops::{ControlFlow, DerefMut}; +use core::ops::ControlFlow; use core::pin::Pin; use core::task::{Context, Poll}; use std::fmt::Display; -use futures::stream::{once, FuturesOrdered, Once}; +use futures::stream::once; use futures::{Stream, StreamExt}; use octseq::Octets; use tracing::{debug, error, trace, warn}; @@ -17,6 +17,7 @@ use crate::base::message_builder::{AdditionalBuilder, PushError}; use crate::base::wire::{Composer, ParseError}; use crate::base::StreamTarget; use crate::net::server::message::{Request, TransportSpecificContext}; +use crate::net::server::middleware::util::MiddlewareStream; use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::{mk_builder_for_target, start_reply}; @@ -295,67 +296,6 @@ impl //--- Service -pub enum MiddlewareStream -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Self: Unpin, - Target: Unpin, -{ - /// The inner service response will be passed through this service without - /// modification. - Passthru(InnerServiceResponseStream), - - /// The inner service response will be post-processed by this service. - Postprocess( - PostprocessingMap, - ), - - /// A single response has been created without invoking the inner service. - HandledOne(Once, ServiceError>>>), - - /// Multiple responses have been created without invoking the inner - /// service. - HandledMany( - FuturesOrdered, ServiceError>>>, - ), -} - -impl Stream - for MiddlewareStream -where - RequestOctets: Octets, - S: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Target: Composer + Default + Unpin, -{ - type Item = Result, ServiceError>; - - fn poll_next( - mut self: core::pin::Pin<&mut Self>, - cx: &mut core::task::Context<'_>, - ) -> core::task::Poll> { - match self.deref_mut() { - MiddlewareStream::Passthru(s) => s.poll_next_unpin(cx), - MiddlewareStream::Postprocess(s) => s.poll_next_unpin(cx), - MiddlewareStream::HandledOne(s) => s.poll_next_unpin(cx), - MiddlewareStream::HandledMany(s) => s.poll_next_unpin(cx), - } - } - - fn size_hint(&self) -> (usize, Option) { - match self { - MiddlewareStream::Passthru(s) => s.size_hint(), - MiddlewareStream::Postprocess(s) => s.size_hint(), - MiddlewareStream::HandledOne(s) => s.size_hint(), - MiddlewareStream::HandledMany(s) => s.size_hint(), - } - } -} - impl Service for MandatoryMiddlewareSvc where @@ -368,13 +308,17 @@ where Target: Composer + Default + 'static + Unpin, { type Target = Target; - type Stream = MiddlewareStream; + type Stream = MiddlewareStream< + S::Stream, + PostprocessingStream, + Target, + >; fn call(&self, request: Request) -> Self::Stream { match self.preprocess(&request) { ControlFlow::Continue(()) => { let st = self.inner.call(request.clone()); - let map = PostprocessingMap::new(st, request, self.strict); + let map = PostprocessingStream::new(st, request, self.strict); MiddlewareStream::Postprocess(map) } ControlFlow::Break(mut response) => { @@ -387,28 +331,32 @@ where } } -pub struct PostprocessingMap -where +pub struct PostprocessingStream< + RequestOctets, + Target, + InnerServiceResponseStream, +> where RequestOctets: Octets, - St: futures::stream::Stream< + InnerServiceResponseStream: futures::stream::Stream< Item = Result, ServiceError>, >, { request: Request, strict: bool, _phantom: PhantomData, - stream: St, + stream: InnerServiceResponseStream, } -impl PostprocessingMap +impl + PostprocessingStream where RequestOctets: Octets, - St: futures::stream::Stream< + InnerServiceResponseStream: futures::stream::Stream< Item = Result, ServiceError>, >, { pub(crate) fn new( - stream: St, + stream: InnerServiceResponseStream, request: Request, strict: bool, ) -> Self { @@ -421,11 +369,15 @@ where } } -impl Stream - for PostprocessingMap +impl Stream + for PostprocessingStream< + RequestOctets, + Target, + InnerServiceResponseStream, + > where RequestOctets: Octets, - St: futures::stream::Stream< + InnerServiceResponseStream: futures::stream::Stream< Item = Result, ServiceError>, > + Unpin, Target: Composer + Default + Unpin, @@ -442,7 +394,13 @@ where Poll::Ready(res.map(|mut res| { if let Ok(cr) = &mut res { if let Some(response) = cr.get_response_mut() { - MandatoryMiddlewareSvc::::postprocess(&request, response, strict); + MandatoryMiddlewareSvc::< + RequestOctets, + InnerServiceResponseStream, + Target, + >::postprocess( + &request, response, strict + ); } } res diff --git a/src/net/server/middleware/util.rs b/src/net/server/middleware/util.rs new file mode 100644 index 000000000..c408450ec --- /dev/null +++ b/src/net/server/middleware/util.rs @@ -0,0 +1,80 @@ +use core::ops::DerefMut; + +use std::future::Ready; + +use futures::stream::{FuturesOrdered, Once}; +use futures::Stream; +use futures_util::StreamExt; + +use crate::base::wire::Composer; +use crate::net::server::service::{CallResult, ServiceError}; + +pub enum MiddlewareStream< + InnerServiceResponseStream, + PostprocessingStream, + Target, +> where + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + PostprocessingStream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + Self: Unpin, + Target: Unpin, +{ + /// The inner service response will be passed through this service without + /// modification. + Passthru(InnerServiceResponseStream), + + /// The inner service response will be post-processed by this service. + Postprocess(PostprocessingStream), + + /// A single response has been created without invoking the inner service. + HandledOne(Once, ServiceError>>>), + + /// Multiple responses have been created without invoking the inner + /// service. + HandledMany( + FuturesOrdered, ServiceError>>>, + ), +} + +impl Stream + for MiddlewareStream< + InnerServiceResponseStream, + PostprocessingStream, + Target, + > +where + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + PostprocessingStream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + Target: Composer + Default + Unpin, +{ + type Item = Result, ServiceError>; + + fn poll_next( + mut self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> core::task::Poll> { + match self.deref_mut() { + MiddlewareStream::Passthru(s) => s.poll_next_unpin(cx), + MiddlewareStream::Postprocess(s) => s.poll_next_unpin(cx), + MiddlewareStream::HandledOne(s) => s.poll_next_unpin(cx), + MiddlewareStream::HandledMany(s) => s.poll_next_unpin(cx), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MiddlewareStream::Passthru(s) => s.size_hint(), + MiddlewareStream::Postprocess(s) => s.size_hint(), + MiddlewareStream::HandledOne(s) => s.size_hint(), + MiddlewareStream::HandledMany(s) => s.size_hint(), + } + } +} From 98fcb6589e38761716b27ae2bdac9213fcbe45de Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:32:15 +0200 Subject: [PATCH 08/28] Adds Service style versions of the EDNS and Cookie middleware too. Also fixes bugs in the Cookie middleware relating to adding the COOKIE option more than once, and not preserving OPT header values when appending OPT options. --- examples/serve-zone.rs | 8 +- .../server/middleware/processors/cookies.rs | 17 +- .../middleware/processors/cookies_svc.rs | 625 ++++++++++++++++++ src/net/server/middleware/processors/edns.rs | 16 +- .../server/middleware/processors/edns_svc.rs | 600 +++++++++++++++++ src/net/server/middleware/processors/mod.rs | 6 +- src/net/server/util.rs | 29 +- 7 files changed, 1284 insertions(+), 17 deletions(-) create mode 100644 src/net/server/middleware/processors/cookies_svc.rs create mode 100644 src/net/server/middleware/processors/edns_svc.rs diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index 948edad4c..a9fd1c481 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -21,6 +21,8 @@ use domain::base::{Dname, Message, Rtype, ToDname}; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; +use domain::net::server::middleware::processors::cookies_svc::CookiesMiddlewareSvc; +use domain::net::server::middleware::processors::edns_svc::EdnsMiddlewareSvc; use domain::net::server::middleware::processors::mandatory_svc::MandatoryMiddlewareSvc; use domain::net::server::service::{CallResult, ServiceError}; use domain::net::server::stream::StreamServer; @@ -66,8 +68,10 @@ async fn main() { let addr = "127.0.0.1:8053"; let business_svc = service_fn(my_service, zones); - let mandatory_svc = MandatoryMiddlewareSvc::new(business_svc); - let svc = Arc::new(mandatory_svc); + + let svc = Arc::new(MandatoryMiddlewareSvc::new(EdnsMiddlewareSvc::new( + CookiesMiddlewareSvc::with_random_secret(business_svc), + ))); let sock = UdpSocket::bind(addr).await.unwrap(); let sock = Arc::new(sock); diff --git a/src/net/server/middleware/processors/cookies.rs b/src/net/server/middleware/processors/cookies.rs index 6f0a245a4..a5bfbd65f 100644 --- a/src/net/server/middleware/processors/cookies.rs +++ b/src/net/server/middleware/processors/cookies.rs @@ -8,7 +8,7 @@ use octseq::Octets; use rand::RngCore; use tracing::{debug, trace, warn}; -use crate::base::iana::{OptRcode, Rcode}; +use crate::base::iana::{OptRcode, OptionCode, Rcode}; use crate::base::message_builder::AdditionalBuilder; use crate::base::opt; use crate::base::opt::Cookie; @@ -142,7 +142,7 @@ impl CookiesMiddlewareProcessor { // Note: if rcode is non-extended this will also correctly handle // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |opt| { + if let Err(err) = add_edns_options(&mut additional, |_, opt| { opt.cookie(response_cookie)?; opt.set_rcode(rcode); Ok(()) @@ -475,9 +475,16 @@ where // option from the request or (b) generating a new COOKIE option // containing both the Client Cookie copied from the request and // a valid Server Cookie it has generated." - if let Err(err) = add_edns_options(response, |builder| { - builder.push(&filled_cookie) - }) { + if let Err(err) = add_edns_options( + response, + |existing_option_codes, builder| { + if !existing_option_codes.contains(&OptionCode::COOKIE) { + builder.push(&filled_cookie) + } else { + Ok(()) + } + }, + ) { warn!("Cannot add RFC 7873 DNS Cookie option to response: {err}"); } } diff --git a/src/net/server/middleware/processors/cookies_svc.rs b/src/net/server/middleware/processors/cookies_svc.rs new file mode 100644 index 000000000..45e778ae8 --- /dev/null +++ b/src/net/server/middleware/processors/cookies_svc.rs @@ -0,0 +1,625 @@ +//! DNS Cookies related message processing. +use core::future::ready; +use core::marker::PhantomData; +use core::ops::ControlFlow; +use core::pin::Pin; +use core::task::{Context, Poll}; + +use std::net::IpAddr; +use std::vec::Vec; + +use futures::stream::once; +use futures::Stream; +use futures_util::StreamExt; +use octseq::Octets; +use rand::RngCore; +use tracing::{debug, enabled, trace, warn, Level}; + +use crate::base::iana::{OptRcode, OptionCode, Rcode}; +use crate::base::message_builder::AdditionalBuilder; +use crate::base::opt; +use crate::base::opt::Cookie; +use crate::base::wire::{Composer, ParseError}; +use crate::base::{Serial, StreamTarget}; +use crate::net::server::message::Request; +use crate::net::server::middleware::util::MiddlewareStream; +use crate::net::server::service::{CallResult, Service, ServiceError}; +use crate::net::server::util::{add_edns_options, to_pcap_text}; +use crate::net::server::util::{mk_builder_for_target, start_reply}; +use std::sync::Arc; + +/// The five minute period referred to by +/// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. +const FIVE_MINUTES_AS_SECS: u32 = 5 * 60; + +/// The one hour period referred to by +/// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. +const ONE_HOUR_AS_SECS: u32 = 60 * 60; + +/// A DNS Cookies [`MiddlewareProcessor`]. +/// +/// Standards covered by ths implementation: +/// +/// | RFC | Status | +/// |--------|---------| +/// | [7873] | TBD | +/// | [9018] | TBD | +/// +/// [7873]: https://datatracker.ietf.org/doc/html/rfc7873 +/// [9018]: https://datatracker.ietf.org/doc/html/rfc7873 +/// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor +#[derive(Debug)] +pub struct CookiesMiddlewareSvc { + inner: S, + + /// A user supplied secret used in making the cookie value. + server_secret: Arc<[u8; 16]>, + + /// Clients connecting from these IP addresses will be required to provide + /// a cookie otherwise they will receive REFUSED with TC=1 prompting them + /// to reconnect with TCP in order to "authenticate" themselves. + ip_deny_list: Vec, +} + +impl CookiesMiddlewareSvc { + /// Creates an instance of this processor. + #[must_use] + pub fn new(inner: S, server_secret: [u8; 16]) -> Self { + Self { + inner, + server_secret: Arc::new(server_secret), + ip_deny_list: vec![], + } + } + + pub fn with_random_secret(inner: S) -> Self { + let mut server_secret = [0u8; 16]; + rand::thread_rng().fill_bytes(&mut server_secret); + Self::new(inner, server_secret) + } + + /// Define IP addresses required to supply DNS cookies if using UDP. + #[must_use] + pub fn with_denied_ips>>( + mut self, + ip_deny_list: T, + ) -> Self { + self.ip_deny_list = ip_deny_list.into(); + self + } +} + +impl CookiesMiddlewareSvc { + /// Get the DNS COOKIE, if any, for the given message. + /// + /// https://datatracker.ietf.org/doc/html/rfc7873#section-5.2: Responding + /// to a Request: "In all cases of multiple COOKIE options in a request, + /// only the first (the one closest to the DNS header) is considered. + /// All others are ignored." + /// + /// Returns: + /// - `None` if the request has no cookie, + /// - Some(Ok(cookie)) if the request has a cookie in the correct + /// format, + /// - Some(Err(err)) if the request has a cookie that we could not + /// parse. + #[must_use] + fn cookie( + request: &Request, + ) -> Option> { + // Note: We don't use `opt::Opt::first()` because that will silently + // ignore an unparseable COOKIE option but we need to detect and + // handle that case. TODO: Should we warn in some way if the request + // has more than one COOKIE option? + request + .message() + .opt() + .and_then(|opt| opt.opt().iter::().next()) + } + + /// Check whether or not the given timestamp is okay. + /// + /// Returns true if the given timestamp is within the permitted difference + /// to now as specified by [RFC 9018 section 4.3]. + /// + /// [RFC 9018 section 4.3]: https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3 + #[must_use] + fn timestamp_ok(serial: Serial) -> bool { + // https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3 + // 4.3. The Timestamp Sub-Field: + // "The Timestamp value prevents Replay Attacks and MUST be checked + // by the server to be within a defined period of time. The DNS + // server SHOULD allow cookies within a 1-hour period in the past + // and a 5-minute period into the future to allow operation of + // low-volume clients and some limited time skew between the DNS + // servers in the anycast set." + let now = Serial::now(); + let too_new_at = now.add(FIVE_MINUTES_AS_SECS); + let expires_at = serial.add(ONE_HOUR_AS_SECS); + now <= expires_at && serial <= too_new_at + } + + /// Create a DNS response message for the given request, including cookie. + fn response_with_cookie( + &self, + request: &Request, + rcode: OptRcode, + ) -> AdditionalBuilder> + where + RequestOctets: Octets, + Target: Composer + Default, + { + let mut additional = start_reply(request).additional(); + + if let Some(Ok(client_cookie)) = Self::cookie(request) { + let response_cookie = client_cookie.create_response( + Serial::now(), + request.client_addr().ip(), + &self.server_secret, + ); + + // Note: if rcode is non-extended this will also correctly handle + // setting the rcode in the main message header. + if let Err(err) = add_edns_options(&mut additional, |_, opt| { + opt.cookie(response_cookie)?; + opt.set_rcode(rcode); + Ok(()) + }) { + warn!("Failed to add cookie to response: {err}"); + } + } + + additional + } + + /// Create a DNS error response message indicating that the client + /// supplied cookie is not okay. + /// + /// Panics + /// + /// This function will panic if the given request does not include a DNS + /// client cookie or is unable to write to an internal buffer while + /// constructing the response. + #[must_use] + fn bad_cookie_response( + &self, + request: &Request, + ) -> AdditionalBuilder> + where + RequestOctets: Octets, + Target: Composer + Default, + { + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 + // "If the server responds [ed: by sending a BADCOOKIE error + // response], it SHALL generate its own COOKIE option containing + // both the Client Cookie copied from the request and a Server + // Cookie it has generated, and it will add this COOKIE option to + // the response's OPT record. + + self.response_with_cookie(request, OptRcode::BADCOOKIE) + } + + /// Create a DNS response to a client cookie prefetch request. + #[must_use] + fn prefetch_cookie_response( + &self, + request: &Request, + ) -> AdditionalBuilder> + where + RequestOctets: Octets, + Target: Composer + Default, + { + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 + // Querying for a Server Cookie: + // "For servers with DNS Cookies enabled, the + // QUERY opcode behavior is extended to support queries with an + // empty Question Section (a QDCOUNT of zero (0)), provided that an + // OPT record is present with a COOKIE option. Such servers will + // send a reply that has an empty Answer Section and has a COOKIE + // option containing the Client Cookie and a valid Server Cookie. + // + // If such a query provided just a Client Cookie and no Server + // Cookie, the response SHALL have the RCODE NOERROR." + self.response_with_cookie(request, Rcode::NOERROR.into()) + } + + /// Check the cookie contained in the request to make sure that it is + /// complete, and if so return the cookie to the caller. + #[must_use] + fn ensure_cookie_is_complete( + request: &Request, + server_secret: &[u8; 16], + ) -> Option { + if let Some(Ok(cookie)) = Self::cookie(request) { + let cookie = if cookie.server().is_some() { + cookie + } else { + cookie.create_response( + Serial::now(), + request.client_addr().ip(), + server_secret, + ) + }; + + Some(cookie) + } else { + None + } + } +} + +//--- MiddlewareProcessor + +impl CookiesMiddlewareSvc { + fn preprocess( + &self, + request: &Request, + ) -> ControlFlow>> + where + RequestOctets: Octets, + Target: Composer + Default, + { + match Self::cookie(request) { + None => { + trace!("Request does not include DNS cookies"); + + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 + // No OPT RR or No COOKIE Option: + // "If there is no OPT record or no COOKIE option + // present in the request, then the server responds to + // the request as if the server doesn't implement the + // COOKIE option." + + // For clients on the IP deny list they MUST authenticate + // themselves to the server, either with a cookie or by + // re-connecting over TCP, so we REFUSE them and reply with + // TC=1 to prompt them to reconnect via TCP. + if request.transport_ctx().is_udp() + && self.ip_deny_list.contains(&request.client_addr().ip()) + { + debug!( + "Rejecting cookie-less non-TCP request due to matching IP deny list entry" + ); + let builder = mk_builder_for_target(); + let mut additional = builder.additional(); + additional.header_mut().set_rcode(Rcode::REFUSED); + additional.header_mut().set_tc(true); + return ControlFlow::Break(additional); + } else { + trace!("Permitting cookie-less request to flow due to use of TCP transport"); + } + } + + Some(Err(err)) => { + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.2 + // Malformed COOKIE Option: + // "If the COOKIE option is too short to contain a + // Client Cookie, then FORMERR is generated. If the + // COOKIE option is longer than that required to hold a + // COOKIE option with just a Client Cookie (8 bytes) + // but is shorter than the minimum COOKIE option with + // both a Client Cookie and a Server Cookie (16 bytes), + // then FORMERR is generated. If the COOKIE option is + // longer than the maximum valid COOKIE option (40 + // bytes), then FORMERR is generated." + + // TODO: Should we warn in some way about the exact reason + // for rejecting the request? + + // NOTE: The RFC doesn't say that we should send our server + // cookie back with the response, so we don't do that here + // unlike in the other cases where we respond early. + debug!("Received malformed DNS cookie: {err}"); + let mut builder = mk_builder_for_target(); + builder.header_mut().set_rcode(Rcode::FORMERR); + return ControlFlow::Break(builder.additional()); + } + + Some(Ok(cookie)) => { + // TODO: Does the "at least occasionally" condition below + // referencing RFC 7873 section 5.2.3 mean that (a) we don't + // have to do this for every response, and (b) we might want + // to add configuration settings for controlling how often we + // do this? + + let server_cookie_exists = cookie.server().is_some(); + let server_cookie_is_valid = cookie.check_server_hash( + request.client_addr().ip(), + &self.server_secret, + Self::timestamp_ok, + ); + + if !server_cookie_is_valid { + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 + // Only a Client Cookie: + // "Based on server policy, including rate limiting, the + // server chooses one of the following: + // + // (1) Silently discard the request. + // + // (2) Send a BADCOOKIE error response. + // + // (3) Process the request and provide a normal + // response. The RCODE is NOERROR, unless some + // non-cookie error occurs in processing the + // request. + // + // ... ... + // + // Servers MUST, at least occasionally, respond to such + // requests to inform the client of the correct Server + // Cookie. + // + // ... ... + // + // If the request was received over TCP, the + // server SHOULD take the authentication + // provided by the use of TCP into account and + // SHOULD choose (3). In this case, if the + // server is not willing to accept the security + // provided by TCP as a substitute for the + // security provided by DNS Cookies but instead + // chooses (2), there is some danger of an + // indefinite loop of retries (see Section + // 5.3)." + + // TODO: Does "(1)" above in combination with the text in + // section 5.2.5 "SHALL process the request" mean that we + // are not allowed to reject the request prior to this + // point based on rate limiting or other server policy? + + // TODO: Should we add a configuration option that allows + // for choosing between approaches (1), (2) and (3)? For + // now err on the side of security and go with approach + // (2): send a BADCOOKIE response. + + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 + // Querying for a Server Cookie: + // "For servers with DNS Cookies enabled, the QUERY + // opcode behavior is extended to support queries with + // an empty Question Section (a QDCOUNT of zero (0)), + // provided that an OPT record is present with a COOKIE + // option. Such servers will send a reply that has an + // empty Answer Section and has a COOKIE option + // containing the Client Cookie and a valid Server + // Cookie. + + // TODO: Does the TCP check also apply to RFC 7873 section + // 5.4 "Querying for a Server Cookie" too? + + if request.message().header_counts().qdcount() == 0 { + let additional = if !server_cookie_exists { + // "If such a query provided just a Client Cookie + // and no Server Cookie, the response SHALL have + // the RCODE NOERROR." + trace!( + "Replying to DNS cookie pre-fetch request with missing server cookie"); + self.prefetch_cookie_response(request) + } else { + // "In this case, the response SHALL have the + // RCODE BADCOOKIE if the Server Cookie sent with + // the query was invalid" + debug!( + "Rejecting pre-fetch request due to invalid server cookie"); + self.bad_cookie_response(request) + }; + return ControlFlow::Break(additional); + } else if request.transport_ctx().is_udp() { + let additional = self.bad_cookie_response(request); + debug!( + "Rejecting non-TCP request due to invalid server cookie"); + return ControlFlow::Break(additional); + } + } else if request.message().header_counts().qdcount() == 0 { + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 + // Querying for a Server Cookie: + // "This mechanism can also be used to + // confirm/re-establish an existing Server Cookie by + // sending a cached Server Cookie with the Client + // Cookie. In this case, the response SHALL have the + // RCODE BADCOOKIE if the Server Cookie sent with the + // query was invalid and the RCODE NOERROR if it was + // valid." + + // TODO: Does the TCP check also apply to RFC 7873 section + // 5.4 "Querying for a Server Cookie" too? + trace!( + "Replying to DNS cookie pre-fetch request with valid server cookie"); + let additional = self.prefetch_cookie_response(request); + return ControlFlow::Break(additional); + } else { + trace!("Request has a valid DNS cookie"); + } + } + } + + trace!("Permitting request to flow"); + + ControlFlow::Continue(()) + } + + fn postprocess( + request: &Request, + response: &mut AdditionalBuilder>, + server_secret: &[u8; 16], + ) where + RequestOctets: Octets, + Target: Composer + Default, + { + trace!("4"); + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 + // No OPT RR or No COOKIE Option: + // If the request lacked a client cookie we don't need to do + // anything. + // + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.2 + // Malformed COOKIE Option: + // If the request COOKIE option was malformed we would have already + // rejected it during pre-processing so again nothing to do here. + // + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 + // Only a Client Cookie: + // If the request had a client cookie but no server cookie and + // we didn't already reject the request during pre-processing. + // + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.4 + // A Client Cookie and an Invalid Server Cookie: + // Per RFC 7873 this is handled the same way as the "Only a Client + // Cookie" case. + // + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 + // A Client Cookie and a Valid Server Cookie + // Any server cookie will already have been validated during + // pre-processing, we don't need to check it again here. + + if let Some(filled_cookie) = + Self::ensure_cookie_is_complete(request, server_secret) + { + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 + // "The server SHALL process the request and include a COOKIE + // option in the response by (a) copying the complete COOKIE + // option from the request or (b) generating a new COOKIE option + // containing both the Client Cookie copied from the request and + // a valid Server Cookie it has generated." + if let Err(err) = add_edns_options( + response, + |existing_option_codes, builder| { + if !existing_option_codes.contains(&OptionCode::COOKIE) { + builder.push(&filled_cookie) + } else { + Ok(()) + } + }, + ) { + warn!("Cannot add RFC 7873 DNS Cookie option to response: {err}"); + } + } + + if enabled!(Level::TRACE) { + let bytes = response.as_slice(); + let pcap_text = to_pcap_text(bytes, bytes.len()); + trace!(pcap_text, "post-processing complete"); + } + } +} + +//--- Service + +impl Service + for CookiesMiddlewareSvc +where + RequestOctets: Octets + 'static, + S: Service, + S::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin + + 'static, + Target: Composer + Default + 'static + Unpin, +{ + type Target = Target; + type Stream = MiddlewareStream< + S::Stream, + PostprocessingStream, + Target, + >; + + fn call(&self, request: Request) -> Self::Stream { + match self.preprocess(&request) { + ControlFlow::Continue(()) => { + let st = self.inner.call(request.clone()); + let map = PostprocessingStream::new( + st, + request, + self.server_secret.clone(), + ); + MiddlewareStream::Postprocess(map) + } + ControlFlow::Break(mut response) => { + Self::postprocess( + &request, + &mut response, + &self.server_secret, + ); + + MiddlewareStream::HandledOne(once(ready(Ok( + CallResult::new(response), + )))) + } + } + } +} + +pub struct PostprocessingStream< + RequestOctets, + Target, + InnerServiceResponseStream, +> where + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + >, +{ + request: Request, + server_secret: Arc<[u8; 16]>, + stream: InnerServiceResponseStream, + _phantom: PhantomData, +} + +impl<'a, RequestOctets, Target, InnerServiceResponseStream> + PostprocessingStream +where + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + >, +{ + pub(crate) fn new( + stream: InnerServiceResponseStream, + request: Request, + server_secret: Arc<[u8; 16]>, + ) -> Self { + Self { + stream, + request, + server_secret, + _phantom: PhantomData, + } + } +} + +impl Stream + for PostprocessingStream< + RequestOctets, + Target, + InnerServiceResponseStream, + > +where + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + Target: Composer + Default + Unpin, +{ + type Item = Result, ServiceError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = futures::ready!(self.stream.poll_next_unpin(cx)); + let request = self.request.clone(); + let server_secret = self.server_secret.clone(); + Poll::Ready(res.map(|mut res| { + if let Ok(cr) = &mut res { + if let Some(response) = cr.get_response_mut() { + CookiesMiddlewareSvc::::postprocess(&request, response, &server_secret); + } + } + res + })) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} diff --git a/src/net/server/middleware/processors/edns.rs b/src/net/server/middleware/processors/edns.rs index 54ce92f44..d465a5ad1 100644 --- a/src/net/server/middleware/processors/edns.rs +++ b/src/net/server/middleware/processors/edns.rs @@ -4,7 +4,7 @@ use core::ops::ControlFlow; use octseq::Octets; use tracing::{debug, enabled, error, trace, warn, Level}; -use crate::base::iana::OptRcode; +use crate::base::iana::{OptRcode, OptionCode}; use crate::base::message_builder::AdditionalBuilder; use crate::base::opt::keepalive::IdleTimeout; use crate::base::opt::{Opt, OptRecord, TcpKeepalive}; @@ -64,7 +64,7 @@ impl EdnsMiddlewareProcessor { // Note: if rcode is non-extended this will also correctly handle // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |opt| { + if let Err(err) = add_edns_options(&mut additional, |_, opt| { opt.set_rcode(rcode); Ok(()) }) { @@ -291,10 +291,14 @@ where // using the edns-tcp-keepalive EDNS(0) option // [RFC7828]". if let Err(err) = - add_edns_options(response, |builder| { - builder.push(&TcpKeepalive::new( - Some(timeout), - )) + add_edns_options(response, |existing_option_codes, builder| { + if !existing_option_codes.contains(&OptionCode::TCP_KEEPALIVE) { + builder.push(&TcpKeepalive::new( + Some(timeout), + )) + } else { + Ok(()) + } }) { warn!("Cannot add RFC 7828 edns-tcp-keepalive option to response: {err}"); diff --git a/src/net/server/middleware/processors/edns_svc.rs b/src/net/server/middleware/processors/edns_svc.rs new file mode 100644 index 000000000..b20f04c35 --- /dev/null +++ b/src/net/server/middleware/processors/edns_svc.rs @@ -0,0 +1,600 @@ +//! RFC 6891 and related EDNS message processing. +use core::future::ready; +use core::marker::PhantomData; +use core::ops::ControlFlow; +use core::task::{Context, Poll}; + +use std::pin::Pin; + +use futures::stream::once; +use futures::Stream; +use futures_util::StreamExt; +use octseq::Octets; +use tracing::{debug, enabled, error, trace, warn, Level}; + +use crate::base::iana::{OptRcode, OptionCode}; +use crate::base::message_builder::AdditionalBuilder; +use crate::base::opt::keepalive::IdleTimeout; +use crate::base::opt::{Opt, OptRecord, TcpKeepalive}; +use crate::base::wire::Composer; +use crate::base::StreamTarget; +use crate::net::server::message::{Request, TransportSpecificContext}; +use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; +use crate::net::server::middleware::util::MiddlewareStream; +use crate::net::server::service::{CallResult, Service, ServiceError}; +use crate::net::server::util::start_reply; +use crate::net::server::util::{add_edns_options, remove_edns_opt_record}; + +/// EDNS version 0. +/// +/// Version 0 is the highest EDNS version number recoded in the [IANA +/// registry] at the time of writing. +/// +/// [IANA registry]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-14 +const EDNS_VERSION_ZERO: u8 = 0; + +/// A [`MiddlewareProcessor`] for adding EDNS(0) related functionality. +/// +/// Standards covered by ths implementation: +/// +/// | RFC | Status | +/// |--------|---------| +/// | [6891] | TBD | +/// | [7828] | TBD | +/// | [9210] | TBD | +/// +/// [6891]: https://datatracker.ietf.org/doc/html/rfc6891 +/// [7828]: https://datatracker.ietf.org/doc/html/rfc7828 +/// [9210]: https://datatracker.ietf.org/doc/html/rfc9210 +/// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor +#[derive(Debug, Default)] +pub struct EdnsMiddlewareSvc { + inner: S, +} + +impl EdnsMiddlewareSvc { + /// Creates an instance of this processor. + #[must_use] + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl EdnsMiddlewareSvc { + /// Create a DNS error response to the given request with the given RCODE. + fn error_response( + request: &Request, + rcode: OptRcode, + ) -> AdditionalBuilder> + where + RequestOctets: Octets, + Target: Composer + Default, + { + let mut additional = start_reply(request).additional(); + + // Note: if rcode is non-extended this will also correctly handle + // setting the rcode in the main message header. + if let Err(err) = add_edns_options(&mut additional, |_, opt| { + opt.set_rcode(rcode); + Ok(()) + }) { + warn!( + "Failed to set (extended) error '{rcode}' in response: {err}" + ); + } + + Self::postprocess(request, &mut additional); + additional + } +} + +//--- MiddlewareProcessor + +impl EdnsMiddlewareSvc { + fn preprocess( + &self, + request: &Request, + ) -> ControlFlow>> + where + RequestOctets: Octets, + Target: Composer + Default, + { + // https://www.rfc-editor.org/rfc/rfc6891.html#section-6.1.1 + // 6.1.1: Basic Elements + // ... + // "If a query message with more than one OPT RR is received, a + // FORMERR (RCODE=1) MUST be returned" + let msg = request.message().clone(); + if let Ok(additional) = msg.additional() { + let mut iter = additional.limit_to::>(); + if let Some(opt) = iter.next() { + if iter.next().is_some() { + // More than one OPT RR received. + debug!("RFC 6891 6.1.1 violation: request contains more than one OPT RR."); + return ControlFlow::Break(Self::error_response( + request, + OptRcode::FORMERR, + )); + } + + if let Ok(opt) = opt { + let opt_rec = OptRecord::from(opt); + + // https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.3 + // 6.1.3. OPT Record TTL Field Use + // "If a responder does not implement the VERSION level + // of the request, then it MUST respond with + // RCODE=BADVERS." + if opt_rec.version() > EDNS_VERSION_ZERO { + debug!("RFC 6891 6.1.3 violation: request EDNS version {} > 0", opt_rec.version()); + return ControlFlow::Break(Self::error_response( + request, + OptRcode::BADVERS, + )); + } + + match request.transport_ctx() { + TransportSpecificContext::Udp(ctx) => { + // https://datatracker.ietf.org/doc/html/rfc7828#section-3.2.1 + // 3.2.1. Sending Queries + // "DNS clients MUST NOT include the + // edns-tcp-keepalive option in queries sent + // using UDP transport." + // TODO: We assume there is only one keep-alive + // option in the request. Should we check for + // multiple? Neither RFC 6891 nor RFC 7828 seem to + // disallow multiple keep alive options in the OPT + // RDATA but multiple at once seems strange. + if opt_rec.opt().tcp_keepalive().is_some() { + debug!("RFC 7828 3.2.1 violation: edns-tcp-keepalive option received via UDP"); + return ControlFlow::Break( + Self::error_response( + request, + OptRcode::FORMERR, + ), + ); + } + + // https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.3 + // 6.2.3. Requestor's Payload Size + // "The requestor's UDP payload size (encoded in + // the RR CLASS field) is the number of octets + // of the largest UDP payload that can be + // reassembled and delivered in the requestor's + // network stack. Note that path MTU, with or + // without fragmentation, could be smaller than + // this. + // + // Values lower than 512 MUST be treated as + // equal to 512." + let requestors_udp_payload_size = + opt_rec.udp_payload_size(); + + if requestors_udp_payload_size + < MINIMUM_RESPONSE_BYTE_LEN + { + debug!("RFC 6891 6.2.3 violation: OPT RR class (requestor's UDP payload size) < {MINIMUM_RESPONSE_BYTE_LEN}"); + } + + // Clamp the lower bound of the size limit + // requested by the client: + let clamped_requestors_udp_payload_size = + u16::max(512, requestors_udp_payload_size); + + // Clamp the upper bound of the size limit + // requested by the server: + let server_max_response_size_hint = + ctx.max_response_size_hint(); + let clamped_server_hint = + server_max_response_size_hint.map(|v| { + v.clamp( + MINIMUM_RESPONSE_BYTE_LEN, + clamped_requestors_udp_payload_size, + ) + }); + + // Use the clamped client size limit if no server hint exists, + // otherwise use the smallest of the client and server limits + // while not going lower than 512 bytes. + let negotiated_hint = match clamped_server_hint { + Some(clamped_server_hint) => u16::min( + clamped_requestors_udp_payload_size, + clamped_server_hint, + ), + + None => clamped_requestors_udp_payload_size, + }; + + if enabled!(Level::TRACE) { + trace!("EDNS(0) response size negotation concluded: client requested={}, server requested={:?}, chosen value={}", + opt_rec.udp_payload_size(), server_max_response_size_hint, negotiated_hint); + } + + ctx.set_max_response_size_hint(Some( + negotiated_hint, + )); + } + + TransportSpecificContext::NonUdp(_) => { + // https://datatracker.ietf.org/doc/html/rfc7828#section-3.2.1 + // 3.2.1. Sending Queries + // "Clients MUST specify an OPTION-LENGTH of 0 + // and omit the TIMEOUT value." + if let Some(keep_alive) = + opt_rec.opt().tcp_keepalive() + { + if keep_alive.timeout().is_some() { + debug!("RFC 7828 3.2.1 violation: edns-tcp-keepalive option received via TCP contains timeout"); + return ControlFlow::Break( + Self::error_response( + request, + OptRcode::FORMERR, + ), + ); + } + } + } + } + } + } + } + + ControlFlow::Continue(()) + } + + fn postprocess( + request: &Request, + response: &mut AdditionalBuilder>, + ) where + RequestOctets: Octets, + Target: Composer + Default, + { + // https://www.rfc-editor.org/rfc/rfc6891.html#section-6.1.1 + // 6.1.1: Basic Elements + // ... + // "If an OPT record is present in a received request, compliant + // responders MUST include an OPT record in their respective + // responses." + // + // We don't do anything about this scenario at present. + + // https://www.rfc-editor.org/rfc/rfc6891.html#section-7 + // 7: Transport considerations + // ... + // "Lack of presence of an OPT record in a request MUST be taken as an + // indication that the requestor does not implement any part of this + // specification and that the responder MUST NOT include an OPT + // record in its response." + // + // So strip off any OPT record present if the query lacked an OPT + // record. + if request.message().opt().is_none() { + if let Err(err) = remove_edns_opt_record(response) { + error!( + "Error while stripping OPT record from response: {err}" + ); + *response = Self::error_response(request, OptRcode::SERVFAIL); + return; + } + } + + // https://datatracker.ietf.org/doc/html/rfc7828#section-3.3.2 + // 3.3.2. Sending Responses + // "A DNS server that receives a query sent using TCP transport that + // includes an OPT RR (with or without the edns-tcp-keepalive + // option) MAY include the edns-tcp-keepalive option in the + // response to signal the expected idle timeout on a connection. + // Servers MUST specify the TIMEOUT value that is currently + // associated with the TCP session." + // + // https://datatracker.ietf.org/doc/html/rfc9210#section-4.2 + // 4.2. Connection Management + // "... DNS clients and servers SHOULD signal their timeout values + // using the edns-tcp-keepalive EDNS(0) option [RFC7828]." + if let TransportSpecificContext::NonUdp(ctx) = request.transport_ctx() + { + if let Some(idle_timeout) = ctx.idle_timeout() { + if let Ok(additional) = request.message().additional() { + let mut iter = additional.limit_to::>(); + if iter.next().is_some() { + match IdleTimeout::try_from(idle_timeout) { + Ok(timeout) => { + // Request has an OPT RR and server idle + // timeout is known: "Signal the timeout value + // using the edns-tcp-keepalive EDNS(0) option + // [RFC7828]". + if let Err(err) = add_edns_options( + response, + |existing_option_codes, builder| { + if !existing_option_codes.contains( + &OptionCode::TCP_KEEPALIVE, + ) { + builder.push(&TcpKeepalive::new( + Some(timeout), + )) + } else { + Ok(()) + } + }, + ) { + warn!("Cannot add RFC 7828 edns-tcp-keepalive option to response: {err}"); + } + } + + Err(err) => { + warn!("Cannot add RFC 7828 edns-tcp-keepalive option to response: invalid timeout: {err}"); + } + } + } + } + } + } + + // TODO: For UDP EDNS capable clients (those that included an OPT + // record in the request) should we set the Requestor's Payload Size + // field to some value? + } +} + +//--- Service + +impl Service for EdnsMiddlewareSvc +where + RequestOctets: Octets + 'static, + S: Service, + S::Stream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin + + 'static, + Target: Composer + Default + 'static + Unpin, +{ + type Target = Target; + type Stream = MiddlewareStream< + S::Stream, + PostprocessingStream, + Target, + >; + + fn call(&self, request: Request) -> Self::Stream { + match self.preprocess(&request) { + ControlFlow::Continue(()) => { + let st = self.inner.call(request.clone()); + let map = PostprocessingStream::new(st, request); + MiddlewareStream::Postprocess(map) + } + ControlFlow::Break(mut response) => { + Self::postprocess(&request, &mut response); + MiddlewareStream::HandledOne(once(ready(Ok( + CallResult::new(response), + )))) + } + } + } +} + +pub struct PostprocessingStream< + RequestOctets, + Target, + InnerServiceResponseStream, +> where + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + >, +{ + request: Request, + _phantom: PhantomData, + stream: InnerServiceResponseStream, +} + +impl + PostprocessingStream +where + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + >, +{ + pub(crate) fn new( + stream: InnerServiceResponseStream, + request: Request, + ) -> Self { + Self { + stream, + request, + _phantom: PhantomData, + } + } +} + +impl Stream + for PostprocessingStream< + RequestOctets, + Target, + InnerServiceResponseStream, + > +where + RequestOctets: Octets, + InnerServiceResponseStream: futures::stream::Stream< + Item = Result, ServiceError>, + > + Unpin, + Target: Composer + Default + Unpin, +{ + type Item = Result, ServiceError>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + let res = futures::ready!(self.stream.poll_next_unpin(cx)); + let request = self.request.clone(); + Poll::Ready(res.map(|mut res| { + if let Ok(cr) = &mut res { + if let Some(response) = cr.get_response_mut() { + EdnsMiddlewareSvc::::postprocess(&request, response); + } + } + res + })) + } + + fn size_hint(&self) -> (usize, Option) { + self.stream.size_hint() + } +} + +#[cfg(test)] +mod tests { + use core::pin::Pin; + + use std::boxed::Box; + use std::vec::Vec; + + use bytes::Bytes; + use futures::stream::Once; + use futures::stream::StreamExt; + use tokio::time::Instant; + + use crate::base::{Dname, Message, MessageBuilder, Rtype}; + use crate::net::server::message::{ + Request, TransportSpecificContext, UdpTransportContext, + }; + + use crate::base::iana::Rcode; + use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; + use crate::net::server::service::{CallResult, Service, ServiceError}; + use crate::net::server::util::{mk_builder_for_target, service_fn}; + + use super::EdnsMiddlewareSvc; + + //------------ Constants ------------------------------------------------- + + const MIN_ALLOWED: Option = Some(MINIMUM_RESPONSE_BYTE_LEN); + const TOO_SMALL: Option = Some(511); + const JUST_RIGHT: Option = MIN_ALLOWED; + const HUGE: Option = Some(u16::MAX); + + //------------ Tests ----------------------------------------------------- + + #[tokio::test] + async fn clamp_max_response_size_correctly() { + // Neither client or server specified a max UDP response size. + assert_eq!(process(None, None).await, None); + + // --- Only server specified max UDP response sizes + // + // The EdnsMiddlewareProcessor should leave these untouched as no EDNS + // option was present in the request, only the server hint exists, and + // EdnsMiddlewareProcessor only acts if the client EDNS option is + // present. + assert_eq!(process(None, TOO_SMALL).await, TOO_SMALL); + assert_eq!(process(None, JUST_RIGHT).await, JUST_RIGHT); + assert_eq!(process(None, HUGE).await, HUGE); + + // --- Only client specified max UDP response sizes + // + // The EdnsMiddlewareProcessor should adopt these, after clamping + // them. + assert_eq!(process(TOO_SMALL, None).await, JUST_RIGHT); + assert_eq!(process(JUST_RIGHT, None).await, JUST_RIGHT); + assert_eq!(process(HUGE, None).await, HUGE); + + // --- Both client and server specified max UDP response sizes + // + // The EdnsMiddlewareProcessor should negotiate the largest size + // acceptable to both sides. + assert_eq!(process(TOO_SMALL, TOO_SMALL).await, MIN_ALLOWED); + assert_eq!(process(TOO_SMALL, JUST_RIGHT).await, JUST_RIGHT); + assert_eq!(process(TOO_SMALL, HUGE).await, MIN_ALLOWED); + assert_eq!(process(JUST_RIGHT, TOO_SMALL).await, JUST_RIGHT); + assert_eq!(process(JUST_RIGHT, JUST_RIGHT).await, JUST_RIGHT); + assert_eq!(process(JUST_RIGHT, HUGE).await, JUST_RIGHT); + assert_eq!(process(HUGE, TOO_SMALL).await, MIN_ALLOWED); + assert_eq!(process(HUGE, JUST_RIGHT).await, JUST_RIGHT); + assert_eq!(process(HUGE, HUGE).await, HUGE); + } + + //------------ Helper functions ------------------------------------------ + + async fn process( + client_value: Option, + server_value: Option, + ) -> Option { + // Build a dummy DNS query. + let query = MessageBuilder::new_vec(); + + // With a dummy question. + let mut query = query.question(); + query.push((Dname::::root(), Rtype::A)).unwrap(); + + // And if requested, a requestor's UDP payload size: + let message: Message<_> = if let Some(v) = client_value { + let mut additional = query.additional(); + additional + .opt(|builder| { + builder.set_udp_payload_size(v); + Ok(()) + }) + .unwrap(); + additional.into_message() + } else { + query.into_message() + }; + + // Package the query into a context aware request to make it look + // as if it came from a UDP server. + let ctx = UdpTransportContext::new(server_value); + let request = Request::new( + "127.0.0.1:12345".parse().unwrap(), + Instant::now(), + message, + TransportSpecificContext::Udp(ctx), + ); + + fn my_service( + req: Request>, + _meta: (), + ) -> Once< + Pin< + Box< + dyn std::future::Future< + Output = Result< + CallResult>, + ServiceError, + >, + > + Send, + >, + >, + > { + // For each request create a single response: + let msg = req.message().clone(); + futures::stream::once(Box::pin(async move { + let builder = mk_builder_for_target(); + let answer = builder.start_answer(&msg, Rcode::NXDOMAIN)?; + Ok(CallResult::new(answer.additional())) + })) + } + + // Either call the service directly. + let my_svc = service_fn(my_service, ()); + let mut stream = my_svc.call(request.clone()); + let _call_result: CallResult> = + stream.next().await.unwrap().unwrap(); + + // Or pass the query through the middleware processor + let processor_svc = EdnsMiddlewareSvc::new(my_svc); + let mut stream = processor_svc.call(request.clone()); + let call_result: CallResult> = + stream.next().await.unwrap().unwrap(); + let (_response, _feedback) = call_result.into_inner(); + + // Get the modified response size hint. + let TransportSpecificContext::Udp(modified_udp_context) = + request.transport_ctx() + else { + unreachable!() + }; + + modified_udp_context.max_response_size_hint() + } +} diff --git a/src/net/server/middleware/processors/mod.rs b/src/net/server/middleware/processors/mod.rs index b0add717c..ce5dd2352 100644 --- a/src/net/server/middleware/processors/mod.rs +++ b/src/net/server/middleware/processors/mod.rs @@ -1,8 +1,12 @@ //! Pre-supplied [`MiddlewareProcessor`] implementations. //! //! [`MiddlewareProcessor`]: super::processor::MiddlewareProcessor + #[cfg(feature = "siphasher")] pub mod cookies; +#[cfg(feature = "siphasher")] +pub mod cookies_svc; pub mod edns; +pub mod edns_svc; pub mod mandatory; -pub mod mandatory_svc; +pub mod mandatory_svc; \ No newline at end of file diff --git a/src/net/server/util.rs b/src/net/server/util.rs index aa55560fe..ec234572e 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -15,7 +15,8 @@ use crate::rdata::AllRecordData; use super::message::Request; use super::service::{CallResult, Service, ServiceError}; -use crate::base::iana::Rcode; +use crate::base::iana::{OptionCode, Rcode}; +use smallvec::SmallVec; //----------- mk_builder_for_target() ---------------------------------------- @@ -198,12 +199,22 @@ where /// /// If the response already has an OPT record the options will be added to /// that. Otherwise an OPT record will be created to hold the new options. +/// +/// Similar to [`AdditionalBuilder::opt`] a caller supplied closure is passed +/// an [`OptBuilder`] which can be used to add EDNS options and set EDNS +/// header fields. +/// +/// However, unlike [`AdditionalBuilder::opt`], the closure is also passed a +/// collection of option codes for the options that already exist so that the +/// caller can avoid adding the same type of option more than once if that is +/// important to them. pub fn add_edns_options( response: &mut AdditionalBuilder>, op: F, ) -> Result<(), PushError> where F: FnOnce( + &[OptionCode], &mut OptBuilder>, ) -> Result< (), @@ -252,12 +263,24 @@ where // the options within the existing OPT record plus the new options // that we want to add. let res = response.opt(|builder| { + let mut existing_option_codes = + SmallVec::<[OptionCode; 4]>::new(); + // Copy the header fields + builder.set_version(current_opt.version()); + builder.set_dnssec_ok(current_opt.dnssec_ok()); + builder.set_rcode(current_opt.rcode(copied_response.header())); + builder.set_udp_payload_size(current_opt.udp_payload_size()); + + // Copy the options for opt in current_opt.opt().iter::>().flatten() { + existing_option_codes.push(opt.code()); builder.push(&opt)?; } - op(builder) + + // Invoking the user supplied callback + op(&existing_option_codes, builder) }); return res; @@ -265,7 +288,7 @@ where } // No existing OPT record in the additional section so build a new one. - response.opt(op) + response.opt(|builder| op(&[], builder)) } /// Removes any OPT records present in the response. From 3dc28bee927433d7125fa627eac86a385446f5a0 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 16 Apr 2024 14:38:54 +0200 Subject: [PATCH 09/28] Remove the old middleware processors, pocessor trait, chain and builder, none of it is needed in the new middleware-as-layered-services world. --- examples/serve-zone.rs | 6 +- src/net/server/message.rs | 1 - src/net/server/middleware/builder.rs | 139 ----- src/net/server/middleware/chain.rs | 180 ------- .../{processors/cookies_svc.rs => cookies.rs} | 0 .../{processors/edns_svc.rs => edns.rs} | 4 +- .../mandatory_svc.rs => mandatory.rs} | 7 +- src/net/server/middleware/mod.rs | 36 +- src/net/server/middleware/processor.rs | 39 -- .../server/middleware/processors/cookies.rs | 492 ------------------ src/net/server/middleware/processors/edns.rs | 439 ---------------- .../server/middleware/processors/mandatory.rs | 421 --------------- src/net/server/middleware/processors/mod.rs | 12 - src/net/server/util.rs | 7 +- 14 files changed, 16 insertions(+), 1767 deletions(-) delete mode 100644 src/net/server/middleware/builder.rs delete mode 100644 src/net/server/middleware/chain.rs rename src/net/server/middleware/{processors/cookies_svc.rs => cookies.rs} (100%) rename src/net/server/middleware/{processors/edns_svc.rs => edns.rs} (99%) rename src/net/server/middleware/{processors/mandatory_svc.rs => mandatory.rs} (99%) delete mode 100644 src/net/server/middleware/processor.rs delete mode 100644 src/net/server/middleware/processors/cookies.rs delete mode 100644 src/net/server/middleware/processors/edns.rs delete mode 100644 src/net/server/middleware/processors/mandatory.rs delete mode 100644 src/net/server/middleware/processors/mod.rs diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index a9fd1c481..b13a04969 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -21,9 +21,9 @@ use domain::base::{Dname, Message, Rtype, ToDname}; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; -use domain::net::server::middleware::processors::cookies_svc::CookiesMiddlewareSvc; -use domain::net::server::middleware::processors::edns_svc::EdnsMiddlewareSvc; -use domain::net::server::middleware::processors::mandatory_svc::MandatoryMiddlewareSvc; +use domain::net::server::middleware::cookies::CookiesMiddlewareSvc; +use domain::net::server::middleware::edns::EdnsMiddlewareSvc; +use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; use domain::net::server::service::{CallResult, ServiceError}; use domain::net::server::stream::StreamServer; use domain::net::server::util::{mk_builder_for_target, service_fn}; diff --git a/src/net/server/message.rs b/src/net/server/message.rs index c77664767..0fc06003a 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -7,7 +7,6 @@ use tokio::time::Instant; use crate::base::Message; - //------------ UdpTransportContext ------------------------------------------- /// Request context for a UDP transport. diff --git a/src/net/server/middleware/builder.rs b/src/net/server/middleware/builder.rs deleted file mode 100644 index 91b8b64b0..000000000 --- a/src/net/server/middleware/builder.rs +++ /dev/null @@ -1,139 +0,0 @@ -//! Middleware builders. -use std::sync::Arc; -use std::vec::Vec; - -use octseq::Octets; - -use crate::base::wire::Composer; - -use super::chain::MiddlewareChain; -use super::processor::MiddlewareProcessor; -use super::processors::edns::EdnsMiddlewareProcessor; -use super::processors::mandatory::MandatoryMiddlewareProcessor; - -/// A [`MiddlewareChain`] builder. -/// -/// A [`MiddlewareChain`] is immutable and so cannot be constructed one -/// [`MiddlewareProcessor`] at a time. -/// -/// This builder allows you to add [`MiddlewareProcessor`]s sequentially using -/// [`push`] before finally calling [`build`] to turn the builder into an -/// immutable [`MiddlewareChain`]. -/// -/// [`push`]: Self::push() -/// [`build`]: Self::build() -pub struct MiddlewareBuilder, Target = Vec> { - /// The ordered set of processors which will pre-process requests and then - /// in reverse order will post-process responses. - processors: Vec< - Arc< - dyn MiddlewareProcessor - + Send - + Sync - + 'static, - >, - >, -} - -impl MiddlewareBuilder -where - RequestOctets: Octets, - Target: Composer + Default, -{ - /// Create a new empty builder. - /// - ///
Warning: - /// - /// When building a standards compliant DNS server you should probably use - /// [`MiddlewareBuilder::minimal`] or [`MiddlewareBuilder::standard`] - /// instead. - ///
- /// - /// [`MiddlewareBuilder::minimal`]: Self::minimal() - /// [`MiddlewareBuilder::standard`]: Self::standard() - #[must_use] - pub fn new() -> Self { - Self { processors: vec![] } - } - - /// Creates a new builder pre-populated with "minimal" middleware - /// processors. - /// - /// The default configuration pre-populates the builder with a - /// [`MandatoryMiddlewareProcessor`] in the chain. - /// - /// This is the minimum most normal DNS servers probably need to comply - /// with applicable RFC standards for DNS servers, only special cases like - /// testing and research may want a chain that doesn't start with the - /// mandatory processor. - #[must_use] - pub fn minimal() -> Self { - let mut builder = Self::new(); - builder.push(MandatoryMiddlewareProcessor::default().into()); - builder - } - - /// Creates a new builder pre-populated with "standard" middleware - /// processors. - /// - /// The constructed builder will be pre-populated with the following - /// [`MiddlewareProcessor`]s in their [`Default`] configuration. - /// - /// - [`MandatoryMiddlewareProcessor`] - /// - [`EdnsMiddlewareProcessor`] - #[must_use] - pub fn standard() -> Self { - let mut builder = Self::new(); - - builder.push(MandatoryMiddlewareProcessor::default().into()); - - #[allow(clippy::default_constructed_unit_structs)] - builder.push(EdnsMiddlewareProcessor::default().into()); - - builder - } - - /// Add a [`MiddlewareProcessor`] to the end of the chain. - /// - /// Processors later in the chain pre-process requests after, and - /// post-process responses before, than processors earlier in the chain. - pub fn push(&mut self, processor: Arc) - where - T: MiddlewareProcessor + Send + Sync + 'static, - { - self.processors.push(processor); - } - - /// Add a [`MiddlewareProcessor`] to the start of the chain. - /// - /// Processors later in the chain pre-process requests after, and - /// post-process responses before, processors earlier in the chain. - pub fn push_front(&mut self, processor: Arc) - where - T: MiddlewareProcessor + Send + Sync + 'static, - { - self.processors.insert(0, processor); - } - - /// Turn the builder into an immutable [`MiddlewareChain`]. - #[must_use] - pub fn build(self) -> MiddlewareChain { - MiddlewareChain::new(self.processors) - } -} - -//--- Default - -impl Default - for MiddlewareBuilder -where - RequestOctets: Octets, - Target: Composer + Default, -{ - /// Create a middleware builder with default, aka "standard", processors. - /// - /// See [`Self::standard`]. - fn default() -> Self { - Self::standard() - } -} diff --git a/src/net/server/middleware/chain.rs b/src/net/server/middleware/chain.rs deleted file mode 100644 index 470cd0659..000000000 --- a/src/net/server/middleware/chain.rs +++ /dev/null @@ -1,180 +0,0 @@ -//! Chaining [`MiddlewareProcessor`]s together. -use core::future::ready; -use core::ops::{ControlFlow, RangeTo}; - -use std::fmt::Debug; -use std::sync::Arc; -use std::vec::Vec; - -use futures::stream::once; - -use crate::base::message_builder::AdditionalBuilder; -use crate::base::wire::Composer; -use crate::base::StreamTarget; -use crate::net::server::message::Request; -use crate::net::server::service::{CallResult, ServiceError}; - -use super::processor::MiddlewareProcessor; - -/// A chain of [`MiddlewareProcessor`]s. -/// -/// Processors earlier in the chain process requests _before_ and responses -/// _after_ processors later in the chain. -/// -/// The chain can be cloned in order to use it with more than one server at -/// once, assuming that you want to use exactly the same set of processors for -/// all servers using the same chain. -/// -/// A [`MiddlewareChain`] is immutable. Requests should not be post-processed -/// by a different or modified chain than they were pre-processed by. -#[derive(Default)] -pub struct MiddlewareChain { - /// The ordered set of processors which will pre-process requests and then - /// in reverse order will post-process responses. - processors: Arc< - Vec< - Arc + Sync + Send>, - >, - >, -} - -impl MiddlewareChain { - /// Create a new _empty_ chain of processors. - /// - ///
Warning: - /// - /// Most DNS server implementations will need to perform mandatory - /// pre-processing of requests and post-processing of responses in order - /// to comply with RFC defined standards. - /// - /// By using this function you are responsible for ensuring that you - /// perform such processing yourself. - /// - /// Most users should **NOT** use this function but should instead use - /// [`MiddlewareBuilder::default`] which constructs a chain that starts - /// with [`MandatoryMiddlewareProcessor`]. - ///
- /// - /// [`MiddlewareBuilder::default`]: - /// super::builder::MiddlewareBuilder::default() - /// [`MandatoryMiddlewareProcessor`]: - /// super::processors::mandatory::MandatoryMiddlewareProcessor - #[must_use] - pub fn new( - processors: Vec< - Arc + Send + Sync>, - >, - ) -> MiddlewareChain { - Self { - processors: Arc::new(processors), - } - } -} - -impl MiddlewareChain -where - RequestOctets: AsRef<[u8]> + Send + 'static, - Target: Composer + Default + Send + 'static, -{ - /// Walks the chain forward invoking pre-processors one by one. - /// - /// Pre-processors may inspect the given [`Request`] but may not generally - /// edit the request. There is some very limited support for editing the - /// context of the request but not the original DNS message contained - /// within it. - /// - /// Returns either [`ControlFlow::Continue`] indicating that processing of - /// the request should continue, or [`ControlFlow::Break`] indicating that - /// a pre-processor decided to terminate processing of the request. - /// - /// On [`ControlFlow::Break`] the caller should pass the given result to - /// [`postprocess`][Self::postprocess]. If processing terminated early the - /// result includes the index of the pre-processor which terminated the - /// processing. - /// - /// # Performance - /// - /// Pre-processing may take place in the same task that handles receipt - /// and pre-processing of other requests. It is therefore important to - /// finish pre-processing as quickly as possible. It is also important to - /// put pre-processors which protect the server against doing too much - /// work as early in the chain as possible. - #[allow(clippy::type_complexity)] - pub fn preprocess( - &self, - request: &Request, - ) -> ControlFlow<( - impl futures::stream::Stream< - Item = Result, ServiceError>, - > + Send, - usize, - )> { - for (i, p) in self.processors.iter().enumerate() { - match p.preprocess(request) { - ControlFlow::Continue(()) => { - // Pre-processing complete, move on to the next pre-processor. - } - - ControlFlow::Break(response) => { - // Stop pre-processing, return the produced response - // (after first applying post-processors to it). - let item = ready(Ok(CallResult::new(response))); - return ControlFlow::Break((once(item), i)); - } - } - } - - ControlFlow::Continue(()) - } - - /// Walks the chain backward invoking post-processors one by one. - /// - /// Post-processors either inspect the given response, or may also - /// optionally modify it. - /// - /// The request supplied should be the request to which the response was - /// generated. This is used e.g. for copying the request DNS message ID - /// into the response, or for checking the transport by which the reques - /// was recieved. - /// - /// The optional `last_processor_idx` value should come from an earlier - /// call to [`preprocess`][Self::preprocess]. Post-processing will start - /// with this processor and walk backward from there, post-processors - /// further down the chain will not be invoked. - pub fn postprocess( - &self, - request: &Request, - response: &mut AdditionalBuilder>, - last_processor_idx: Option, - ) { - let processors = match last_processor_idx { - Some(end) => &self.processors[RangeTo { end }], - None => &self.processors[..], - }; - - processors - .iter() - .rev() - .for_each(|p| p.postprocess(request, response)); - } -} - -//--- Clone - -impl Clone for MiddlewareChain { - fn clone(&self) -> Self { - Self { - processors: self.processors.clone(), - } - } -} - -//--- Debug - -impl Debug for MiddlewareChain { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("MiddlewareChain") - .field("processors", &self.processors.len()) - .finish() - } -} diff --git a/src/net/server/middleware/processors/cookies_svc.rs b/src/net/server/middleware/cookies.rs similarity index 100% rename from src/net/server/middleware/processors/cookies_svc.rs rename to src/net/server/middleware/cookies.rs diff --git a/src/net/server/middleware/processors/edns_svc.rs b/src/net/server/middleware/edns.rs similarity index 99% rename from src/net/server/middleware/processors/edns_svc.rs rename to src/net/server/middleware/edns.rs index b20f04c35..a5cca9f30 100644 --- a/src/net/server/middleware/processors/edns_svc.rs +++ b/src/net/server/middleware/edns.rs @@ -12,6 +12,7 @@ use futures_util::StreamExt; use octseq::Octets; use tracing::{debug, enabled, error, trace, warn, Level}; +use super::mandatory::MINIMUM_RESPONSE_BYTE_LEN; use crate::base::iana::{OptRcode, OptionCode}; use crate::base::message_builder::AdditionalBuilder; use crate::base::opt::keepalive::IdleTimeout; @@ -19,7 +20,6 @@ use crate::base::opt::{Opt, OptRecord, TcpKeepalive}; use crate::base::wire::Composer; use crate::base::StreamTarget; use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; use crate::net::server::middleware::util::MiddlewareStream; use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::start_reply; @@ -461,7 +461,7 @@ mod tests { }; use crate::base::iana::Rcode; - use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; + use crate::net::server::middleware::mandatory::MINIMUM_RESPONSE_BYTE_LEN; use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::{mk_builder_for_target, service_fn}; diff --git a/src/net/server/middleware/processors/mandatory_svc.rs b/src/net/server/middleware/mandatory.rs similarity index 99% rename from src/net/server/middleware/processors/mandatory_svc.rs rename to src/net/server/middleware/mandatory.rs index c3146961c..70d071bc6 100644 --- a/src/net/server/middleware/processors/mandatory_svc.rs +++ b/src/net/server/middleware/mandatory.rs @@ -12,12 +12,12 @@ use futures::{Stream, StreamExt}; use octseq::Octets; use tracing::{debug, error, trace, warn}; +use super::util::MiddlewareStream; use crate::base::iana::{Opcode, Rcode}; use crate::base::message_builder::{AdditionalBuilder, PushError}; use crate::base::wire::{Composer, ParseError}; use crate::base::StreamTarget; use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::middleware::util::MiddlewareStream; use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::{mk_builder_for_target, start_reply}; @@ -462,17 +462,16 @@ mod tests { use octseq::OctetsBuilder; use tokio::time::Instant; - use super::MandatoryMiddlewareSvc; - use crate::base::iana::{OptionCode, Rcode}; use crate::base::{Dname, MessageBuilder, Rtype}; use crate::net::server::message::{ Request, TransportSpecificContext, UdpTransportContext, }; - use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; use crate::net::server::service::{CallResult, Service, ServiceError}; use crate::net::server::util::{mk_builder_for_target, service_fn}; + use super::{MandatoryMiddlewareSvc, MINIMUM_RESPONSE_BYTE_LEN}; + //------------ Constants ------------------------------------------------- const MIN_ALLOWED: u16 = MINIMUM_RESPONSE_BYTE_LEN; diff --git a/src/net/server/middleware/mod.rs b/src/net/server/middleware/mod.rs index ae61ca78d..816f10339 100644 --- a/src/net/server/middleware/mod.rs +++ b/src/net/server/middleware/mod.rs @@ -1,33 +1,5 @@ -//! Request pre-processing and response post-processing. -//! -//! Middleware sits in the middle between the server nearest the client and -//! the [`Service`] that implements the application logic. -//! -//! Middleware pre-processes requests and post-processes responses to -//! filter/reject/modify them according to policy and standards. -//! -//! Middleware processing should happen immediately after receipt of a request -//! (to ensure the least resources are spent on processing malicious requests) -//! and immediately prior to writing responses back to the client (to ensure -//! that what is sent to the client is correct). -//! -//! Mandatory functionality and logic required by all standards compliant DNS -//! servers can be incorporated into your server by building a middleware -//! chain starting from [`MiddlewareBuilder::default`]. -//! -//! A selection of additional functionality relating to server behaviour and -//! DNS standards (as opposed to your own application logic) is provided which -//! you can incorporate into your DNS server via [`MiddlewareBuilder::push`]. -//! See the various implementations of [`MiddlewareProcessor`] for more -//! information. -//! -//! [`MiddlewareBuilder::default`]: builder::MiddlewareBuilder::default() -//! [`MiddlewareBuilder::push`]: builder::MiddlewareBuilder::push() -//! [`MiddlewareChain`]: chain::MiddlewareChain -//! [`MiddlewareProcessor`]: processor::MiddlewareProcessor -//! [`Service`]: crate::net::server::service::Service -pub mod builder; -pub mod chain; -pub mod processor; -pub mod processors; +#[cfg(feature = "siphasher")] +pub mod cookies; +pub mod edns; +pub mod mandatory; pub mod util; diff --git a/src/net/server/middleware/processor.rs b/src/net/server/middleware/processor.rs deleted file mode 100644 index 1b40d85da..000000000 --- a/src/net/server/middleware/processor.rs +++ /dev/null @@ -1,39 +0,0 @@ -//! Supporting types common to all processors. -use core::ops::ControlFlow; - -use crate::base::message_builder::AdditionalBuilder; -use crate::base::StreamTarget; -use crate::net::server::message::Request; - -/// A processing stage applied to incoming and outgoing messages. -/// -/// See the documentation in the [`middleware`] module for more information. -/// -/// [`middleware`]: crate::net::server::middleware -pub trait MiddlewareProcessor -where - RequestOctets: AsRef<[u8]>, -{ - /// Apply middleware pre-processing rules to a request. - /// - /// See [`MiddlewareChain::preprocess`] for more information. - /// - /// [`MiddlewareChain::preprocess`]: - /// crate::net::server::middleware::chain::MiddlewareChain::preprocess() - fn preprocess( - &self, - request: &Request, - ) -> ControlFlow>>; - - /// Apply middleware post-processing rules to a response. - /// - /// See [`MiddlewareChain::postprocess`] for more information. - /// - /// [`MiddlewareChain::postprocess`]: - /// crate::net::server::middleware::chain::MiddlewareChain::postprocess() - fn postprocess( - &self, - request: &Request, - response: &mut AdditionalBuilder>, - ); -} diff --git a/src/net/server/middleware/processors/cookies.rs b/src/net/server/middleware/processors/cookies.rs deleted file mode 100644 index a5bfbd65f..000000000 --- a/src/net/server/middleware/processors/cookies.rs +++ /dev/null @@ -1,492 +0,0 @@ -//! DNS Cookies related message processing. -use core::ops::ControlFlow; - -use std::net::IpAddr; -use std::vec::Vec; - -use octseq::Octets; -use rand::RngCore; -use tracing::{debug, trace, warn}; - -use crate::base::iana::{OptRcode, OptionCode, Rcode}; -use crate::base::message_builder::AdditionalBuilder; -use crate::base::opt; -use crate::base::opt::Cookie; -use crate::base::wire::{Composer, ParseError}; -use crate::base::{Serial, StreamTarget}; -use crate::net::server::message::Request; -use crate::net::server::middleware::processor::MiddlewareProcessor; -use crate::net::server::util::add_edns_options; -use crate::net::server::util::{mk_builder_for_target, start_reply}; - -/// The five minute period referred to by -/// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. -const FIVE_MINUTES_AS_SECS: u32 = 5 * 60; - -/// The one hour period referred to by -/// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. -const ONE_HOUR_AS_SECS: u32 = 60 * 60; - -/// A DNS Cookies [`MiddlewareProcessor`]. -/// -/// Standards covered by ths implementation: -/// -/// | RFC | Status | -/// |--------|---------| -/// | [7873] | TBD | -/// | [9018] | TBD | -/// -/// [7873]: https://datatracker.ietf.org/doc/html/rfc7873 -/// [9018]: https://datatracker.ietf.org/doc/html/rfc7873 -/// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor -#[derive(Debug)] -pub struct CookiesMiddlewareProcessor { - /// A user supplied secret used in making the cookie value. - server_secret: [u8; 16], - - /// Clients connecting from these IP addresses will be required to provide - /// a cookie otherwise they will receive REFUSED with TC=1 prompting them - /// to reconnect with TCP in order to "authenticate" themselves. - ip_deny_list: Vec, -} - -impl CookiesMiddlewareProcessor { - /// Creates an instance of this processor. - #[must_use] - pub fn new(server_secret: [u8; 16]) -> Self { - Self { - server_secret, - ip_deny_list: vec![], - } - } - - /// Define IP addresses required to supply DNS cookies if using UDP. - #[must_use] - pub fn with_denied_ips>>( - mut self, - ip_deny_list: T, - ) -> Self { - self.ip_deny_list = ip_deny_list.into(); - self - } -} - -impl CookiesMiddlewareProcessor { - /// Get the DNS COOKIE, if any, for the given message. - /// - /// https://datatracker.ietf.org/doc/html/rfc7873#section-5.2: Responding - /// to a Request: "In all cases of multiple COOKIE options in a request, - /// only the first (the one closest to the DNS header) is considered. - /// All others are ignored." - /// - /// Returns: - /// - `None` if the request has no cookie, - /// - Some(Ok(cookie)) if the request has a cookie in the correct - /// format, - /// - Some(Err(err)) if the request has a cookie that we could not - /// parse. - #[must_use] - fn cookie( - request: &Request, - ) -> Option> { - // Note: We don't use `opt::Opt::first()` because that will silently - // ignore an unparseable COOKIE option but we need to detect and - // handle that case. TODO: Should we warn in some way if the request - // has more than one COOKIE option? - request - .message() - .opt() - .and_then(|opt| opt.opt().iter::().next()) - } - - /// Check whether or not the given timestamp is okay. - /// - /// Returns true if the given timestamp is within the permitted difference - /// to now as specified by [RFC 9018 section 4.3]. - /// - /// [RFC 9018 section 4.3]: https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3 - #[must_use] - fn timestamp_ok(serial: Serial) -> bool { - // https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3 - // 4.3. The Timestamp Sub-Field: - // "The Timestamp value prevents Replay Attacks and MUST be checked - // by the server to be within a defined period of time. The DNS - // server SHOULD allow cookies within a 1-hour period in the past - // and a 5-minute period into the future to allow operation of - // low-volume clients and some limited time skew between the DNS - // servers in the anycast set." - let now = Serial::now(); - let too_new_at = now.add(FIVE_MINUTES_AS_SECS); - let expires_at = serial.add(ONE_HOUR_AS_SECS); - now <= expires_at && serial <= too_new_at - } - - /// Create a DNS response message for the given request, including cookie. - fn response_with_cookie( - &self, - request: &Request, - rcode: OptRcode, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - let mut additional = start_reply(request).additional(); - - if let Some(Ok(client_cookie)) = Self::cookie(request) { - let response_cookie = client_cookie.create_response( - Serial::now(), - request.client_addr().ip(), - &self.server_secret, - ); - - // Note: if rcode is non-extended this will also correctly handle - // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |_, opt| { - opt.cookie(response_cookie)?; - opt.set_rcode(rcode); - Ok(()) - }) { - warn!("Failed to add cookie to response: {err}"); - } - } - - additional - } - - /// Create a DNS error response message indicating that the client - /// supplied cookie is not okay. - /// - /// Panics - /// - /// This function will panic if the given request does not include a DNS - /// client cookie or is unable to write to an internal buffer while - /// constructing the response. - #[must_use] - fn bad_cookie_response( - &self, - request: &Request, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 - // "If the server responds [ed: by sending a BADCOOKIE error - // response], it SHALL generate its own COOKIE option containing - // both the Client Cookie copied from the request and a Server - // Cookie it has generated, and it will add this COOKIE option to - // the response's OPT record. - - self.response_with_cookie(request, OptRcode::BADCOOKIE) - } - - /// Create a DNS response to a client cookie prefetch request. - #[must_use] - fn prefetch_cookie_response( - &self, - request: &Request, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 - // Querying for a Server Cookie: - // "For servers with DNS Cookies enabled, the - // QUERY opcode behavior is extended to support queries with an - // empty Question Section (a QDCOUNT of zero (0)), provided that an - // OPT record is present with a COOKIE option. Such servers will - // send a reply that has an empty Answer Section and has a COOKIE - // option containing the Client Cookie and a valid Server Cookie. - // - // If such a query provided just a Client Cookie and no Server - // Cookie, the response SHALL have the RCODE NOERROR." - self.response_with_cookie(request, Rcode::NOERROR.into()) - } - - /// Check the cookie contained in the request to make sure that it is - /// complete, and if so return the cookie to the caller. - #[must_use] - fn ensure_cookie_is_complete( - &self, - request: &Request, - ) -> Option { - if let Some(Ok(cookie)) = Self::cookie(request) { - let cookie = if cookie.server().is_some() { - cookie - } else { - cookie.create_response( - Serial::now(), - request.client_addr().ip(), - &self.server_secret, - ) - }; - - Some(cookie) - } else { - None - } - } -} - -//--- Default - -impl Default for CookiesMiddlewareProcessor { - /// Creates an instance of this processor with default configuration. - /// - /// The processor will use a randomly generated server secret. - fn default() -> Self { - let mut server_secret = [0u8; 16]; - rand::thread_rng().fill_bytes(&mut server_secret); - - Self { - server_secret, - ip_deny_list: Default::default(), - } - } -} - -//--- MiddlewareProcessor - -impl MiddlewareProcessor - for CookiesMiddlewareProcessor -where - RequestOctets: Octets, - Target: Composer + Default, -{ - fn preprocess( - &self, - request: &Request, - ) -> ControlFlow>> { - match Self::cookie(request) { - None => { - trace!("Request does not include DNS cookies"); - - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 - // No OPT RR or No COOKIE Option: - // "If there is no OPT record or no COOKIE option - // present in the request, then the server responds to - // the request as if the server doesn't implement the - // COOKIE option." - - // For clients on the IP deny list they MUST authenticate - // themselves to the server, either with a cookie or by - // re-connecting over TCP, so we REFUSE them and reply with - // TC=1 to prompt them to reconnect via TCP. - if request.transport_ctx().is_udp() - && self.ip_deny_list.contains(&request.client_addr().ip()) - { - debug!( - "Rejecting cookie-less non-TCP request due to matching IP deny list entry" - ); - let builder = mk_builder_for_target(); - let mut additional = builder.additional(); - additional.header_mut().set_rcode(Rcode::REFUSED); - additional.header_mut().set_tc(true); - return ControlFlow::Break(additional); - } else { - trace!("Permitting cookie-less request to flow due to use of TCP transport"); - } - } - - Some(Err(err)) => { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.2 - // Malformed COOKIE Option: - // "If the COOKIE option is too short to contain a - // Client Cookie, then FORMERR is generated. If the - // COOKIE option is longer than that required to hold a - // COOKIE option with just a Client Cookie (8 bytes) - // but is shorter than the minimum COOKIE option with - // both a Client Cookie and a Server Cookie (16 bytes), - // then FORMERR is generated. If the COOKIE option is - // longer than the maximum valid COOKIE option (40 - // bytes), then FORMERR is generated." - - // TODO: Should we warn in some way about the exact reason - // for rejecting the request? - - // NOTE: The RFC doesn't say that we should send our server - // cookie back with the response, so we don't do that here - // unlike in the other cases where we respond early. - debug!("Received malformed DNS cookie: {err}"); - let mut builder = mk_builder_for_target(); - builder.header_mut().set_rcode(Rcode::FORMERR); - return ControlFlow::Break(builder.additional()); - } - - Some(Ok(cookie)) => { - // TODO: Does the "at least occasionally" condition below - // referencing RFC 7873 section 5.2.3 mean that (a) we don't - // have to do this for every response, and (b) we might want - // to add configuration settings for controlling how often we - // do this? - - let server_cookie_exists = cookie.server().is_some(); - let server_cookie_is_valid = cookie.check_server_hash( - request.client_addr().ip(), - &self.server_secret, - Self::timestamp_ok, - ); - - if !server_cookie_is_valid { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 - // Only a Client Cookie: - // "Based on server policy, including rate limiting, the - // server chooses one of the following: - // - // (1) Silently discard the request. - // - // (2) Send a BADCOOKIE error response. - // - // (3) Process the request and provide a normal - // response. The RCODE is NOERROR, unless some - // non-cookie error occurs in processing the - // request. - // - // ... ... - // - // Servers MUST, at least occasionally, respond to such - // requests to inform the client of the correct Server - // Cookie. - // - // ... ... - // - // If the request was received over TCP, the - // server SHOULD take the authentication - // provided by the use of TCP into account and - // SHOULD choose (3). In this case, if the - // server is not willing to accept the security - // provided by TCP as a substitute for the - // security provided by DNS Cookies but instead - // chooses (2), there is some danger of an - // indefinite loop of retries (see Section - // 5.3)." - - // TODO: Does "(1)" above in combination with the text in - // section 5.2.5 "SHALL process the request" mean that we - // are not allowed to reject the request prior to this - // point based on rate limiting or other server policy? - - // TODO: Should we add a configuration option that allows - // for choosing between approaches (1), (2) and (3)? For - // now err on the side of security and go with approach - // (2): send a BADCOOKIE response. - - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 - // Querying for a Server Cookie: - // "For servers with DNS Cookies enabled, the QUERY - // opcode behavior is extended to support queries with - // an empty Question Section (a QDCOUNT of zero (0)), - // provided that an OPT record is present with a COOKIE - // option. Such servers will send a reply that has an - // empty Answer Section and has a COOKIE option - // containing the Client Cookie and a valid Server - // Cookie. - - // TODO: Does the TCP check also apply to RFC 7873 section - // 5.4 "Querying for a Server Cookie" too? - - if request.message().header_counts().qdcount() == 0 { - let additional = if !server_cookie_exists { - // "If such a query provided just a Client Cookie - // and no Server Cookie, the response SHALL have - // the RCODE NOERROR." - trace!( - "Replying to DNS cookie pre-fetch request with missing server cookie"); - self.prefetch_cookie_response(request) - } else { - // "In this case, the response SHALL have the - // RCODE BADCOOKIE if the Server Cookie sent with - // the query was invalid" - debug!( - "Rejecting pre-fetch request due to invalid server cookie"); - self.bad_cookie_response(request) - }; - return ControlFlow::Break(additional); - } else if request.transport_ctx().is_udp() { - let additional = self.bad_cookie_response(request); - debug!( - "Rejecting non-TCP request due to invalid server cookie"); - return ControlFlow::Break(additional); - } - } else if request.message().header_counts().qdcount() == 0 { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 - // Querying for a Server Cookie: - // "This mechanism can also be used to - // confirm/re-establish an existing Server Cookie by - // sending a cached Server Cookie with the Client - // Cookie. In this case, the response SHALL have the - // RCODE BADCOOKIE if the Server Cookie sent with the - // query was invalid and the RCODE NOERROR if it was - // valid." - - // TODO: Does the TCP check also apply to RFC 7873 section - // 5.4 "Querying for a Server Cookie" too? - trace!( - "Replying to DNS cookie pre-fetch request with valid server cookie"); - let additional = self.prefetch_cookie_response(request); - return ControlFlow::Break(additional); - } else { - trace!("Request has a valid DNS cookie"); - } - } - } - - trace!("Permitting request to flow"); - - ControlFlow::Continue(()) - } - - fn postprocess( - &self, - request: &Request, - response: &mut AdditionalBuilder>, - ) { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 - // No OPT RR or No COOKIE Option: - // If the request lacked a client cookie we don't need to do - // anything. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.2 - // Malformed COOKIE Option: - // If the request COOKIE option was malformed we would have already - // rejected it during pre-processing so again nothing to do here. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 - // Only a Client Cookie: - // If the request had a client cookie but no server cookie and - // we didn't already reject the request during pre-processing. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.4 - // A Client Cookie and an Invalid Server Cookie: - // Per RFC 7873 this is handled the same way as the "Only a Client - // Cookie" case. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 - // A Client Cookie and a Valid Server Cookie - // Any server cookie will already have been validated during - // pre-processing, we don't need to check it again here. - - if let Some(filled_cookie) = self.ensure_cookie_is_complete(request) { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 - // "The server SHALL process the request and include a COOKIE - // option in the response by (a) copying the complete COOKIE - // option from the request or (b) generating a new COOKIE option - // containing both the Client Cookie copied from the request and - // a valid Server Cookie it has generated." - if let Err(err) = add_edns_options( - response, - |existing_option_codes, builder| { - if !existing_option_codes.contains(&OptionCode::COOKIE) { - builder.push(&filled_cookie) - } else { - Ok(()) - } - }, - ) { - warn!("Cannot add RFC 7873 DNS Cookie option to response: {err}"); - } - } - } -} diff --git a/src/net/server/middleware/processors/edns.rs b/src/net/server/middleware/processors/edns.rs deleted file mode 100644 index d465a5ad1..000000000 --- a/src/net/server/middleware/processors/edns.rs +++ /dev/null @@ -1,439 +0,0 @@ -//! RFC 6891 and related EDNS message processing. -use core::ops::ControlFlow; - -use octseq::Octets; -use tracing::{debug, enabled, error, trace, warn, Level}; - -use crate::base::iana::{OptRcode, OptionCode}; -use crate::base::message_builder::AdditionalBuilder; -use crate::base::opt::keepalive::IdleTimeout; -use crate::base::opt::{Opt, OptRecord, TcpKeepalive}; -use crate::base::wire::Composer; -use crate::base::StreamTarget; -use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::middleware::processor::MiddlewareProcessor; -use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; -use crate::net::server::util::start_reply; -use crate::net::server::util::{add_edns_options, remove_edns_opt_record}; - -/// EDNS version 0. -/// -/// Version 0 is the highest EDNS version number recoded in the [IANA -/// registry] at the time of writing. -/// -/// [IANA registry]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-14 -const EDNS_VERSION_ZERO: u8 = 0; - -/// A [`MiddlewareProcessor`] for adding EDNS(0) related functionality. -/// -/// Standards covered by ths implementation: -/// -/// | RFC | Status | -/// |--------|---------| -/// | [6891] | TBD | -/// | [7828] | TBD | -/// | [9210] | TBD | -/// -/// [6891]: https://datatracker.ietf.org/doc/html/rfc6891 -/// [7828]: https://datatracker.ietf.org/doc/html/rfc7828 -/// [9210]: https://datatracker.ietf.org/doc/html/rfc9210 -/// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor -#[derive(Debug, Default)] -pub struct EdnsMiddlewareProcessor; - -impl EdnsMiddlewareProcessor { - /// Creates an instance of this processor. - #[must_use] - pub fn new() -> Self { - Self - } -} - -impl EdnsMiddlewareProcessor { - /// Create a DNS error response to the given request with the given RCODE. - fn error_response( - &self, - request: &Request, - rcode: OptRcode, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - let mut additional = start_reply(request).additional(); - - // Note: if rcode is non-extended this will also correctly handle - // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |_, opt| { - opt.set_rcode(rcode); - Ok(()) - }) { - warn!( - "Failed to set (extended) error '{rcode}' in response: {err}" - ); - } - - self.postprocess(request, &mut additional); - additional - } -} - -//--- MiddlewareProcessor - -impl MiddlewareProcessor - for EdnsMiddlewareProcessor -where - RequestOctets: Octets, - Target: Composer + Default, -{ - fn preprocess( - &self, - request: &Request, - ) -> ControlFlow>> { - // https://www.rfc-editor.org/rfc/rfc6891.html#section-6.1.1 - // 6.1.1: Basic Elements - // ... - // "If a query message with more than one OPT RR is received, a - // FORMERR (RCODE=1) MUST be returned" - let msg = request.message().clone(); - if let Ok(additional) = msg.additional() { - let mut iter = additional.limit_to::>(); - if let Some(opt) = iter.next() { - if iter.next().is_some() { - // More than one OPT RR received. - debug!("RFC 6891 6.1.1 violation: request contains more than one OPT RR."); - return ControlFlow::Break( - self.error_response(request, OptRcode::FORMERR), - ); - } - - if let Ok(opt) = opt { - let opt_rec = OptRecord::from(opt); - - // https://datatracker.ietf.org/doc/html/rfc6891#section-6.1.3 - // 6.1.3. OPT Record TTL Field Use - // "If a responder does not implement the VERSION level - // of the request, then it MUST respond with - // RCODE=BADVERS." - if opt_rec.version() > EDNS_VERSION_ZERO { - debug!("RFC 6891 6.1.3 violation: request EDNS version {} > 0", opt_rec.version()); - return ControlFlow::Break( - self.error_response(request, OptRcode::BADVERS), - ); - } - - match request.transport_ctx() { - TransportSpecificContext::Udp(ctx) => { - // https://datatracker.ietf.org/doc/html/rfc7828#section-3.2.1 - // 3.2.1. Sending Queries - // "DNS clients MUST NOT include the - // edns-tcp-keepalive option in queries sent - // using UDP transport." - // TODO: We assume there is only one keep-alive - // option in the request. Should we check for - // multiple? Neither RFC 6891 nor RFC 7828 seem to - // disallow multiple keep alive options in the OPT - // RDATA but multiple at once seems strange. - if opt_rec.opt().tcp_keepalive().is_some() { - debug!("RFC 7828 3.2.1 violation: edns-tcp-keepalive option received via UDP"); - return ControlFlow::Break( - self.error_response( - request, - OptRcode::FORMERR, - ), - ); - } - - // https://datatracker.ietf.org/doc/html/rfc6891#section-6.2.3 - // 6.2.3. Requestor's Payload Size - // "The requestor's UDP payload size (encoded in - // the RR CLASS field) is the number of octets - // of the largest UDP payload that can be - // reassembled and delivered in the requestor's - // network stack. Note that path MTU, with or - // without fragmentation, could be smaller than - // this. - // - // Values lower than 512 MUST be treated as - // equal to 512." - let requestors_udp_payload_size = - opt_rec.udp_payload_size(); - - if requestors_udp_payload_size - < MINIMUM_RESPONSE_BYTE_LEN - { - debug!("RFC 6891 6.2.3 violation: OPT RR class (requestor's UDP payload size) < {MINIMUM_RESPONSE_BYTE_LEN}"); - } - - // Clamp the lower bound of the size limit - // requested by the client: - let clamped_requestors_udp_payload_size = - u16::max(512, requestors_udp_payload_size); - - // Clamp the upper bound of the size limit - // requested by the server: - let server_max_response_size_hint = - ctx.max_response_size_hint(); - let clamped_server_hint = - server_max_response_size_hint.map(|v| { - v.clamp( - MINIMUM_RESPONSE_BYTE_LEN, - clamped_requestors_udp_payload_size, - ) - }); - - // Use the clamped client size limit if no server hint exists, - // otherwise use the smallest of the client and server limits - // while not going lower than 512 bytes. - let negotiated_hint = match clamped_server_hint { - Some(clamped_server_hint) => u16::min( - clamped_requestors_udp_payload_size, - clamped_server_hint, - ), - - None => clamped_requestors_udp_payload_size, - }; - - if enabled!(Level::TRACE) { - trace!("EDNS(0) response size negotation concluded: client requested={}, server requested={:?}, chosen value={}", - opt_rec.udp_payload_size(), server_max_response_size_hint, negotiated_hint); - } - - ctx.set_max_response_size_hint(Some( - negotiated_hint, - )); - } - - TransportSpecificContext::NonUdp(_) => { - // https://datatracker.ietf.org/doc/html/rfc7828#section-3.2.1 - // 3.2.1. Sending Queries - // "Clients MUST specify an OPTION-LENGTH of 0 - // and omit the TIMEOUT value." - if let Some(keep_alive) = - opt_rec.opt().tcp_keepalive() - { - if keep_alive.timeout().is_some() { - debug!("RFC 7828 3.2.1 violation: edns-tcp-keepalive option received via TCP contains timeout"); - return ControlFlow::Break( - self.error_response( - request, - OptRcode::FORMERR, - ), - ); - } - } - } - } - } - } - } - - ControlFlow::Continue(()) - } - - fn postprocess( - &self, - request: &Request, - response: &mut AdditionalBuilder>, - ) { - // https://www.rfc-editor.org/rfc/rfc6891.html#section-6.1.1 - // 6.1.1: Basic Elements - // ... - // "If an OPT record is present in a received request, compliant - // responders MUST include an OPT record in their respective - // responses." - // - // We don't do anything about this scenario at present. - - // https://www.rfc-editor.org/rfc/rfc6891.html#section-7 - // 7: Transport considerations - // ... - // "Lack of presence of an OPT record in a request MUST be taken as an - // indication that the requestor does not implement any part of this - // specification and that the responder MUST NOT include an OPT - // record in its response." - // - // So strip off any OPT record present if the query lacked an OPT - // record. - if request.message().opt().is_none() { - if let Err(err) = remove_edns_opt_record(response) { - error!( - "Error while stripping OPT record from response: {err}" - ); - *response = self.error_response(request, OptRcode::SERVFAIL); - return; - } - } - - // https://datatracker.ietf.org/doc/html/rfc7828#section-3.3.2 - // 3.3.2. Sending Responses - // "A DNS server that receives a query sent using TCP transport that - // includes an OPT RR (with or without the edns-tcp-keepalive - // option) MAY include the edns-tcp-keepalive option in the - // response to signal the expected idle timeout on a connection. - // Servers MUST specify the TIMEOUT value that is currently - // associated with the TCP session." - // - // https://datatracker.ietf.org/doc/html/rfc9210#section-4.2 - // 4.2. Connection Management - // "... DNS clients and servers SHOULD signal their timeout values - // using the edns-tcp-keepalive EDNS(0) option [RFC7828]." - if let TransportSpecificContext::NonUdp(ctx) = request.transport_ctx() - { - if let Some(idle_timeout) = ctx.idle_timeout() { - if let Ok(additional) = request.message().additional() { - let mut iter = additional.limit_to::>(); - if iter.next().is_some() { - match IdleTimeout::try_from(idle_timeout) { - Ok(timeout) => { - // Request has an OPT RR and server idle - // timeout is known: "Signal the timeout value - // using the edns-tcp-keepalive EDNS(0) option - // [RFC7828]". - if let Err(err) = - add_edns_options(response, |existing_option_codes, builder| { - if !existing_option_codes.contains(&OptionCode::TCP_KEEPALIVE) { - builder.push(&TcpKeepalive::new( - Some(timeout), - )) - } else { - Ok(()) - } - }) - { - warn!("Cannot add RFC 7828 edns-tcp-keepalive option to response: {err}"); - } - } - - Err(err) => { - warn!("Cannot add RFC 7828 edns-tcp-keepalive option to response: invalid timeout: {err}"); - } - } - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use core::ops::ControlFlow; - - use std::vec::Vec; - - use bytes::Bytes; - use tokio::time::Instant; - - use crate::base::{Dname, Message, MessageBuilder, Rtype}; - use crate::net::server::message::{ - Request, TransportSpecificContext, UdpTransportContext, - }; - - use super::EdnsMiddlewareProcessor; - use crate::net::server::middleware::processor::MiddlewareProcessor; - use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; - - //------------ Constants ------------------------------------------------- - - const MIN_ALLOWED: Option = Some(MINIMUM_RESPONSE_BYTE_LEN); - const TOO_SMALL: Option = Some(511); - const JUST_RIGHT: Option = MIN_ALLOWED; - const HUGE: Option = Some(u16::MAX); - - //------------ Tests ----------------------------------------------------- - - #[test] - fn clamp_max_response_size_correctly() { - // Neither client or server specified a max UDP response size. - assert_eq!(process(None, None), None); - - // --- Only server specified max UDP response sizes - // - // The EdnsMiddlewareProcessor should leave these untouched as no EDNS - // option was present in the request, only the server hint exists, and - // EdnsMiddlewareProcessor only acts if the client EDNS option is - // present. - assert_eq!(process(None, TOO_SMALL), TOO_SMALL); - assert_eq!(process(None, JUST_RIGHT), JUST_RIGHT); - assert_eq!(process(None, HUGE), HUGE); - - // --- Only client specified max UDP response sizes - // - // The EdnsMiddlewareProcessor should adopt these, after clamping - // them. - assert_eq!(process(TOO_SMALL, None), JUST_RIGHT); - assert_eq!(process(JUST_RIGHT, None), JUST_RIGHT); - assert_eq!(process(HUGE, None), HUGE); - - // --- Both client and server specified max UDP response sizes - // - // The EdnsMiddlewareProcessor should negotiate the largest size - // acceptable to both sides. - assert_eq!(process(TOO_SMALL, TOO_SMALL), MIN_ALLOWED); - assert_eq!(process(TOO_SMALL, JUST_RIGHT), JUST_RIGHT); - assert_eq!(process(TOO_SMALL, HUGE), MIN_ALLOWED); - assert_eq!(process(JUST_RIGHT, TOO_SMALL), JUST_RIGHT); - assert_eq!(process(JUST_RIGHT, JUST_RIGHT), JUST_RIGHT); - assert_eq!(process(JUST_RIGHT, HUGE), JUST_RIGHT); - assert_eq!(process(HUGE, TOO_SMALL), MIN_ALLOWED); - assert_eq!(process(HUGE, JUST_RIGHT), JUST_RIGHT); - assert_eq!(process(HUGE, HUGE), HUGE); - } - - //------------ Helper functions ------------------------------------------ - - fn process( - client_value: Option, - server_value: Option, - ) -> Option { - // Build a dummy DNS query. - let query = MessageBuilder::new_vec(); - - // With a dummy question. - let mut query = query.question(); - query.push((Dname::::root(), Rtype::A)).unwrap(); - - // And if requested, a requestor's UDP payload size: - let message: Message<_> = if let Some(v) = client_value { - let mut additional = query.additional(); - additional - .opt(|builder| { - builder.set_udp_payload_size(v); - Ok(()) - }) - .unwrap(); - additional.into_message() - } else { - query.into_message() - }; - - // Package the query into a context aware request to make it look - // as if it came from a UDP server. - let ctx = UdpTransportContext::new(server_value); - let request = Request::new( - "127.0.0.1:12345".parse().unwrap(), - Instant::now(), - message, - TransportSpecificContext::Udp(ctx), - ); - - // And pass the query through the middleware processor - let processor = EdnsMiddlewareProcessor::new(); - let processor: &dyn MiddlewareProcessor, Vec> = - &processor; - let mut response = MessageBuilder::new_stream_vec().additional(); - if let ControlFlow::Continue(()) = processor.preprocess(&request) { - processor.postprocess(&request, &mut response); - } - - // Get the modified response size hint. - let TransportSpecificContext::Udp(modified_udp_context) = - request.transport_ctx() - else { - unreachable!() - }; - - modified_udp_context.max_response_size_hint() - } -} diff --git a/src/net/server/middleware/processors/mandatory.rs b/src/net/server/middleware/processors/mandatory.rs deleted file mode 100644 index ac0bdc366..000000000 --- a/src/net/server/middleware/processors/mandatory.rs +++ /dev/null @@ -1,421 +0,0 @@ -//! Core DNS RFC standards based message processing for MUST requirements. -use core::ops::ControlFlow; - -use octseq::Octets; -use tracing::{debug, error, trace, warn}; - -use crate::base::iana::{Opcode, Rcode}; -use crate::base::message_builder::{AdditionalBuilder, PushError}; -use crate::base::wire::{Composer, ParseError}; -use crate::base::StreamTarget; -use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::middleware::processor::MiddlewareProcessor; -use crate::net::server::util::{mk_builder_for_target, start_reply}; -use std::fmt::Display; - -/// The minimum legal UDP response size in bytes. -/// -/// As defined by [RFC 1035 section 4.2.1]. -/// -/// [RFC 1035 section 4.2.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 -pub const MINIMUM_RESPONSE_BYTE_LEN: u16 = 512; - -/// A [`MiddlewareProcessor`] for enforcing core RFC MUST requirements on -/// processed messages. -/// -/// Standards covered by ths implementation: -/// -/// | RFC | Status | -/// |--------|---------| -/// | [1035] | TBD | -/// | [2181] | TBD | -/// -/// [`MiddlewareProcessor`]: -/// crate::net::server::middleware::processor::MiddlewareProcessor -/// [1035]: https://datatracker.ietf.org/doc/html/rfc1035 -/// [2181]: https://datatracker.ietf.org/doc/html/rfc2181 -#[derive(Debug)] -pub struct MandatoryMiddlewareProcessor { - /// In strict mode the processor does more checks on requests and - /// responses. - strict: bool, -} - -impl MandatoryMiddlewareProcessor { - /// Creates a new processor instance. - /// - /// The processor will operate in strict mode. - #[must_use] - pub fn new() -> Self { - Self { strict: true } - } - - /// Creates a new processor instance. - /// - /// The processor will operate in relaxed mode. - #[must_use] - pub fn relaxed() -> Self { - Self { strict: false } - } - - /// Create a DNS error response to the given request with the given RCODE. - fn error_response( - &self, - request: &Request, - rcode: Rcode, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - let mut response = start_reply(request); - response.header_mut().set_rcode(rcode); - let mut additional = response.additional(); - self.postprocess(request, &mut additional); - additional - } -} - -impl MandatoryMiddlewareProcessor { - /// Truncate the given response message if it is too large. - /// - /// Honours either a transport supplied hint, if present in the given - /// [`UdpSpecificTransportContext`], as to how large the response is - /// allowed to be, or if missing will instead honour the clients indicated - /// UDP response payload size (if an EDNS OPT is present in the request). - /// - /// Truncation discards the authority and additional sections, except for - /// any OPT record present which will be preserved, then truncates to the - /// specified byte length. - fn truncate( - request: &Request, - response: &mut AdditionalBuilder>, - ) -> Result<(), TruncateError> - where - Target: Composer + Default, - RequestOctets: AsRef<[u8]>, - { - if let TransportSpecificContext::Udp(ctx) = request.transport_ctx() { - // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 - // "Messages carried by UDP are restricted to 512 bytes (not - // counting the IP or UDP headers). Longer messages are - // truncated and the TC bit is set in the header." - let max_response_size = ctx - .max_response_size_hint() - .unwrap_or(MINIMUM_RESPONSE_BYTE_LEN); - let max_response_size = max_response_size as usize; - let response_len = response.as_slice().len(); - - if response_len > max_response_size { - // Truncate per RFC 1035 section 6.2 and RFC 2181 sections 5.1 - // and 9: - // - // https://datatracker.ietf.org/doc/html/rfc1035#section-6.2 - // "When a response is so long that truncation is required, - // the truncation should start at the end of the response - // and work forward in the datagram. Thus if there is any - // data for the authority section, the answer section is - // guaranteed to be unique." - // - // https://datatracker.ietf.org/doc/html/rfc2181#section-5.1 - // "A query for a specific (or non-specific) label, class, - // and type, will always return all records in the - // associated RRSet - whether that be one or more RRs. The - // response must be marked as "truncated" if the entire - // RRSet will not fit in the response." - // - // https://datatracker.ietf.org/doc/html/rfc2181#section-9 - // "Where TC is set, the partial RRSet that would not - // completely fit may be left in the response. When a DNS - // client receives a reply with TC set, it should ignore - // that response, and query again, using a mechanism, such - // as a TCP connection, that will permit larger replies." - // - // https://datatracker.ietf.org/doc/html/rfc6891#section-7 - // "The minimal response MUST be the DNS header, question - // section, and an OPT record. This MUST also occur when - // a truncated response (using the DNS header's TC bit) is - // returned." - - // Tell the client that we are truncating the response. - response.header_mut().set_tc(true); - - // Remember the original length. - let old_len = response.as_slice().len(); - - // Copy the header, question and opt record from the - // additional section, but leave the answer and authority - // sections empty. - let source = response.as_message(); - let mut target = mk_builder_for_target(); - - *target.header_mut() = source.header(); - - let mut target = target.question(); - for rr in source.question() { - target.push(rr?)?; - } - - let mut target = target.additional(); - if let Some(opt) = source.opt() { - if let Err(err) = target.push(opt.as_record()) { - warn!("Error while truncating response: unable to push OPT record: {err}"); - // As the client had an OPT record and RFC 6891 says - // when truncating that there MUST be an OPT record, - // attempt to push just the empty OPT record (as the - // OPT record header still has value, e.g. the - // requestors payload size field and extended rcode). - if let Err(err) = target.opt(|builder| { - builder.set_version(opt.version()); - builder.set_rcode(opt.rcode(response.header())); - builder - .set_udp_payload_size(opt.udp_payload_size()); - Ok(()) - }) { - error!("Error while truncating response: unable to add minimal OPT record: {err}"); - } - } - } - - let new_len = target.as_slice().len(); - trace!("Truncating response from {old_len} bytes to {new_len} bytes"); - - *response = target; - } - } - - Ok(()) - } -} - -//--- MiddlewareProcessor - -// TODO: If we extend this later to do a lot more than setting a couple of -// header flags, and if we think that there may be a need for alternate -// truncation strategies, then it might make sense to factor out truncation to -// make it "pluggable" by the user. -impl MiddlewareProcessor - for MandatoryMiddlewareProcessor -where - RequestOctets: AsRef<[u8]> + Octets, - Target: Composer + Default, -{ - fn preprocess( - &self, - request: &Request, - ) -> ControlFlow>> { - // https://www.rfc-editor.org/rfc/rfc3425.html - // 3 - Effect on RFC 1035 - // .. - // "Therefore IQUERY is now obsolete, and name servers SHOULD return - // a "Not Implemented" error when an IQUERY request is received." - if self.strict - && request.message().header().opcode() == Opcode::IQUERY - { - debug!( - "RFC 3425 3 violation: request opcode IQUERY is obsolete." - ); - return ControlFlow::Break( - self.error_response(request, Rcode::NOTIMP), - ); - } - - ControlFlow::Continue(()) - } - - fn postprocess( - &self, - request: &Request, - response: &mut AdditionalBuilder>, - ) { - if let Err(err) = Self::truncate(request, response) { - error!("Error while truncating response: {err}"); - *response = self.error_response(request, Rcode::SERVFAIL); - return; - } - - // https://datatracker.ietf.org/doc/html/rfc1035#section-4.1.1 - // 4.1.1: Header section format - // - // ID A 16 bit identifier assigned by the program that - // generates any kind of query. This identifier is copied - // the corresponding reply and can be used by the requester - // to match up replies to outstanding queries. - response - .header_mut() - .set_id(request.message().header().id()); - - // QR A one bit field that specifies whether this message is a - // query (0), or a response (1). - response.header_mut().set_qr(true); - - // RD Recursion Desired - this bit may be set in a query and - // is copied into the response. If RD is set, it directs - // the name server to pursue the query recursively. - // Recursive query support is optional. - response - .header_mut() - .set_rd(request.message().header().rd()); - - // https://www.rfc-editor.org/rfc/rfc1035.html - // https://www.rfc-editor.org/rfc/rfc3425.html - // - // All responses shown in RFC 1035 (except those for inverse queries, - // opcode 1, which was obsoleted by RFC 4325) contain the question - // from the request. So we would expect the number of questions in the - // response to match the number of questions in the request. - if self.strict - && !request.message().header_counts().qdcount() - == response.counts().qdcount() - { - warn!("RFC 1035 violation: response question count != request question count"); - } - } -} - -//--- Default - -impl Default for MandatoryMiddlewareProcessor { - fn default() -> Self { - Self::new() - } -} - -//------------ TruncateError ------------------------------------------------- - -/// An error occured during oversize response truncation. -enum TruncateError { - /// There was a problem parsing the request, specifically the question - /// section. - InvalidQuestion(ParseError), - - /// There was a problem pushing to the response. - PushFailure(PushError), -} - -impl Display for TruncateError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - TruncateError::InvalidQuestion(err) => { - write!(f, "Unable to parse question: {err}") - } - TruncateError::PushFailure(err) => { - write!(f, "Unable to push into response: {err}") - } - } - } -} - -impl From for TruncateError { - fn from(err: ParseError) -> Self { - Self::InvalidQuestion(err) - } -} - -impl From for TruncateError { - fn from(err: PushError) -> Self { - Self::PushFailure(err) - } -} - -#[cfg(test)] -mod tests { - use core::ops::ControlFlow; - - use std::vec::Vec; - - use bytes::Bytes; - use tokio::time::Instant; - - use crate::base::{Dname, MessageBuilder, Rtype}; - use crate::net::server::message::{ - Request, TransportSpecificContext, UdpTransportContext, - }; - - use super::MandatoryMiddlewareProcessor; - use crate::base::iana::OptionCode; - use crate::net::server::middleware::processor::MiddlewareProcessor; - use crate::net::server::middleware::processors::mandatory::MINIMUM_RESPONSE_BYTE_LEN; - use octseq::OctetsBuilder; - - //------------ Constants ------------------------------------------------- - - const MIN_ALLOWED: u16 = MINIMUM_RESPONSE_BYTE_LEN; - const TOO_SMALL: u16 = 511; - const JUST_RIGHT: u16 = MIN_ALLOWED; - const HUGE: u16 = u16::MAX; - - //------------ Tests ----------------------------------------------------- - - #[test] - fn clamp_max_response_size_correctly() { - assert!(process(None) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(TOO_SMALL)) <= Some(MIN_ALLOWED as usize)); - assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); - assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); - assert!(process(Some(JUST_RIGHT)) <= Some(JUST_RIGHT as usize)); - assert!(process(Some(HUGE)) <= Some(HUGE as usize)); - assert!(process(Some(HUGE)) <= Some(HUGE as usize)); - assert!(process(Some(HUGE)) <= Some(HUGE as usize)); - } - - //------------ Helper functions ------------------------------------------ - - // Returns Some(n) if truncation occurred where n is the size after - // truncation. - fn process(max_response_size_hint: Option) -> Option { - // Build a dummy DNS query. - let query = MessageBuilder::new_vec(); - let mut query = query.question(); - query.push((Dname::::root(), Rtype::A)).unwrap(); - let extra_bytes = vec![0; (MIN_ALLOWED as usize) * 2]; - let mut additional = query.additional(); - additional - .opt(|builder| { - builder.push_raw_option( - OptionCode::PADDING, - extra_bytes.len() as u16, - |target| { - target.append_slice(&extra_bytes).unwrap(); - Ok(()) - }, - ) - }) - .unwrap(); - let old_size = additional.as_slice().len(); - let message = additional.into_message(); - - // TODO: Artificially expand the message to be as big as possible - // so that it will get truncated. - - // Package the query into a context aware request to make it look - // as if it came from a UDP server. - let ctx = UdpTransportContext::new(max_response_size_hint); - let request = Request::new( - "127.0.0.1:12345".parse().unwrap(), - Instant::now(), - message, - TransportSpecificContext::Udp(ctx), - ); - - // And pass the query through the middleware processor - let processor = MandatoryMiddlewareProcessor::default(); - let processor: &dyn MiddlewareProcessor, Vec> = - &processor; - let mut response = MessageBuilder::new_stream_vec().additional(); - if let ControlFlow::Continue(()) = processor.preprocess(&request) { - processor.postprocess(&request, &mut response); - } - - // Get the response length - let new_size = response.as_slice().len(); - - if new_size < old_size { - Some(new_size) - } else { - None - } - } -} diff --git a/src/net/server/middleware/processors/mod.rs b/src/net/server/middleware/processors/mod.rs deleted file mode 100644 index ce5dd2352..000000000 --- a/src/net/server/middleware/processors/mod.rs +++ /dev/null @@ -1,12 +0,0 @@ -//! Pre-supplied [`MiddlewareProcessor`] implementations. -//! -//! [`MiddlewareProcessor`]: super::processor::MiddlewareProcessor - -#[cfg(feature = "siphasher")] -pub mod cookies; -#[cfg(feature = "siphasher")] -pub mod cookies_svc; -pub mod edns; -pub mod edns_svc; -pub mod mandatory; -pub mod mandatory_svc; \ No newline at end of file diff --git a/src/net/server/util.rs b/src/net/server/util.rs index ec234572e..c84cf76cb 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -199,11 +199,11 @@ where /// /// If the response already has an OPT record the options will be added to /// that. Otherwise an OPT record will be created to hold the new options. -/// +/// /// Similar to [`AdditionalBuilder::opt`] a caller supplied closure is passed /// an [`OptBuilder`] which can be used to add EDNS options and set EDNS /// header fields. -/// +/// /// However, unlike [`AdditionalBuilder::opt`], the closure is also passed a /// collection of option codes for the options that already exist so that the /// caller can avoid adding the same type of option more than once if that is @@ -268,7 +268,8 @@ where // Copy the header fields builder.set_version(current_opt.version()); builder.set_dnssec_ok(current_opt.dnssec_ok()); - builder.set_rcode(current_opt.rcode(copied_response.header())); + builder + .set_rcode(current_opt.rcode(copied_response.header())); builder.set_udp_payload_size(current_opt.udp_payload_size()); // Copy the options From 0618b82b6b964f197c7d513d05a5443b288cfa18 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 23 Apr 2024 11:22:57 +0200 Subject: [PATCH 10/28] Remove diagnostic trace message accidentally left behind. --- src/net/server/middleware/cookies.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/net/server/middleware/cookies.rs b/src/net/server/middleware/cookies.rs index 45e778ae8..239ef75a9 100644 --- a/src/net/server/middleware/cookies.rs +++ b/src/net/server/middleware/cookies.rs @@ -446,7 +446,6 @@ impl CookiesMiddlewareSvc { RequestOctets: Octets, Target: Composer + Default, { - trace!("4"); // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 // No OPT RR or No COOKIE Option: // If the request lacked a client cookie we don't need to do From 20220a0b18da108637acfdb37085295dd7152701 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Mon, 29 Apr 2024 14:36:58 +0200 Subject: [PATCH 11/28] Test fixes: A single test stream should yield only a single message, and a stream preceeded by its length doesn't already need its length prepended to it. --- src/net/server/tests.rs | 75 ++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index 6687b13f0..ae8511108 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -262,25 +262,38 @@ impl BufSource for MockBufSource { /// A mock single result to be returned by a mock service, just to show that /// it is possible to define your own. -struct MySingle; +struct MySingle { + done: bool, +} + +impl MySingle { + fn new() -> MySingle { + Self { done: false } + } +} impl futures::stream::Stream for MySingle { type Item = Result>, ServiceError>; fn poll_next( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, ) -> Poll> { - let builder = MessageBuilder::new_stream_vec(); - let response = builder.additional(); + if self.done { + Poll::Ready(None) + } else { + let builder = MessageBuilder::new_stream_vec(); + let response = builder.additional(); - let command = ServiceFeedback::Reconfigure { - idle_timeout: Some(Duration::from_millis(5000)), - }; + let command = ServiceFeedback::Reconfigure { + idle_timeout: Some(Duration::from_millis(5000)), + }; - let call_result = CallResult::new(response).with_feedback(command); + let call_result = CallResult::new(response).with_feedback(command); + self.done = true; - Poll::Ready(Some(Ok(call_result))) + Poll::Ready(Some(Ok(call_result))) + } } } @@ -299,7 +312,7 @@ impl Service> for MyService { type Stream = MySingle; fn call(&self, _msg: Request>) -> MySingle { - MySingle + MySingle::new() } } @@ -334,39 +347,15 @@ fn mk_query() -> StreamTarget> { // signal that time has passed when in fact it actually hasn't, allowing a // time dependent test to run much faster without actual periods of // waiting to allow time to elapse. -// #[tokio::test(flavor = "current_thread", start_paused = true)] -// async fn service_test() { -// let (srv_handle, server_status_printer_handle) = { -// let fast_client = MockClientConfig { -// new_message_every: Duration::from_millis(100), -// messages: VecDeque::from([ -// mk_query().as_stream_slice().to_vec(), -// mk_query().as_stream_slice().to_vec(), -// mk_query().as_stream_slice().to_vec(), -// mk_query().as_stream_slice().to_vec(), -// mk_query().as_stream_slice().to_vec(), -// ]), -// client_port: 1, -// }; -// let slow_client = MockClientConfig { -// new_message_every: Duration::from_millis(3000), -// messages: VecDeque::from([ -// mk_query().as_stream_slice().to_vec(), -// mk_query().as_stream_slice().to_vec(), -// ]), -// client_port: 2, -// }; -// let num_messages = -// fast_client.messages.len() + slow_client.messages.len(); -// let streams_to_read = VecDeque::from([fast_client, slow_client]); -// let new_client_every = Duration::from_millis(2000); -// let listener = MockListener::new(streams_to_read, new_client_every); -// let ready_flag = listener.get_ready_flag(); - -// let buf = MockBufSource; -// let my_service = Arc::new(MyService::new()); -// let srv = -// Arc::new(StreamServer::new(listener, buf, my_service.clone())); +// mk_query().as_dgram_slice().to_vec(), +// mk_query().as_dgram_slice().to_vec(), +// mk_query().as_dgram_slice().to_vec(), +// mk_query().as_dgram_slice().to_vec(), +// mk_query().as_dgram_slice().to_vec(), + ]), +// mk_query().as_dgram_slice().to_vec(), +// mk_query().as_dgram_slice().to_vec(), + ]), // let metrics = srv.metrics(); // let server_status_printer_handle = tokio::spawn(async move { From 7549bd0bfc47d1b7380ebbffd5d44541b7189efe Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:41:35 +0200 Subject: [PATCH 12/28] - Layered service based middleware. - Return Future of Stream from Service::call() instead of the custom Transaction type. - Relax trait bounds/move trait bounds nearer to the point of use. - Introduce PostprocessingStream impl of Stream for stream based upstream result post-processing. - Introduce ServiceResult aka io::Result. - Modify the serve_zone example to assume it is powered by (not included in these changes) XFR middleware. --- Cargo.toml | 4 +- examples/serve-zone.rs | 349 +++----------- examples/server-transports.rs | 613 +++++++++++++++---------- src/net/server/connection.rs | 259 ++++++----- src/net/server/dgram.rs | 241 ++-------- src/net/server/message.rs | 6 +- src/net/server/middleware/cookies.rs | 235 ++++------ src/net/server/middleware/edns.rs | 268 ++++------- src/net/server/middleware/mandatory.rs | 269 ++++------- src/net/server/middleware/mod.rs | 2 +- src/net/server/middleware/stream.rs | 166 +++++++ src/net/server/middleware/util.rs | 80 ---- src/net/server/mod.rs | 4 +- src/net/server/service.rs | 256 +++-------- src/net/server/stream.rs | 64 +-- src/net/server/tests.rs | 162 ++++--- src/net/server/util.rs | 70 ++- src/zonetree/zone.rs | 2 +- tests/net-server.rs | 227 ++++----- 19 files changed, 1380 insertions(+), 1897 deletions(-) create mode 100644 src/net/server/middleware/stream.rs delete mode 100644 src/net/server/middleware/util.rs diff --git a/Cargo.toml b/Cargo.toml index 047f22127..99bb3ff51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,9 +17,9 @@ name = "domain" path = "src/lib.rs" [dependencies] -octseq = { version = "0.5.1", default-features = false } +octseq = { version = "0.5.1", default-features = false } pin-project-lite = "0.2" -time = { version = "0.3.1", default-features = false } +time = { version = "0.3.1", default-features = false } rand = { version = "0.8", optional = true } arc-swap = { version = "1.7.0", optional = true } diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index b13a04969..849cef8e8 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -15,32 +15,29 @@ //! //! dig @127.0.0.1 -p 8053 AXFR example.com -use domain::base::iana::{Opcode, Rcode}; -use domain::base::message_builder::AdditionalBuilder; -use domain::base::{Dname, Message, Rtype, ToDname}; +use std::future::pending; +use std::io::BufReader; +use std::sync::Arc; +use std::time::Duration; + +use tokio::net::{TcpListener, UdpSocket}; +use tracing_subscriber::EnvFilter; + +use domain::base::iana::Rcode; +use domain::base::ToDname; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; +#[cfg(feature = "siphasher")] use domain::net::server::middleware::cookies::CookiesMiddlewareSvc; use domain::net::server::middleware::edns::EdnsMiddlewareSvc; use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; -use domain::net::server::service::{CallResult, ServiceError}; +use domain::net::server::service::{CallResult, ServiceResult}; use domain::net::server::stream::StreamServer; use domain::net::server::util::{mk_builder_for_target, service_fn}; use domain::zonefile::inplace; -use domain::zonetree::{Answer, Rrset}; +use domain::zonetree::Answer; use domain::zonetree::{Zone, ZoneTree}; -use futures::stream::{once, FuturesOrdered, Once}; -use futures::StreamExt; -use octseq::OctetsBuilder; -use std::future::{pending, ready, Future}; -use std::io::BufReader; -use std::ops::DerefMut; -use std::pin::Pin; -use std::sync::{Arc, Mutex}; -use std::time::Duration; -use tokio::net::{TcpListener, UdpSocket}; -use tracing_subscriber::EnvFilter; #[tokio::main()] async fn main() { @@ -54,24 +51,50 @@ async fn main() { .ok(); // Populate a zone tree with test data - let mut zones = ZoneTree::new(); let zone_bytes = include_bytes!("../test-data/zonefiles/nsd-example.txt"); let mut zone_bytes = BufReader::new(&zone_bytes[..]); + // let zone_bytes = std::fs::File::open("/etc/nsd/zones/de-zone").unwrap(); + // let mut zone_bytes = BufReader::new(zone_bytes); // We're reading from static data so this cannot fail due to I/O error. // Don't handle errors that shouldn't happen, keep the example focused // on what we want to demonstrate. let reader = inplace::Zonefile::load(&mut zone_bytes).unwrap(); let zone = Zone::try_from(reader).unwrap(); - zones.insert_zone(zone).unwrap(); + + // TODO: Make changes to a zone to create a diff for IXFR use. + // let mut writer = zone.write().await; + // { + // let node = writer.open(true).await.unwrap(); + // let mut new_ns = Rrset::new(Rtype::NS, Ttl::from_secs(60)); + // let ns_rec = domain::rdata::Ns::new( + // Dname::from_str("write-test.example.com").unwrap(), + // ); + // new_ns.push_data(ns_rec.into()); + // node.update_rrset(SharedRrset::new(new_ns)).await.unwrap(); + // } + // let diff = writer.commit().await.unwrap(); + + let mut zones = ZoneTree::new(); + zones.insert_zone(zone.clone()).unwrap(); let zones = Arc::new(zones); let addr = "127.0.0.1:8053"; - let business_svc = service_fn(my_service, zones); - - let svc = Arc::new(MandatoryMiddlewareSvc::new(EdnsMiddlewareSvc::new( - CookiesMiddlewareSvc::with_random_secret(business_svc), - ))); + let svc = service_fn(my_service, zones); + + // TODO: Insert XFR middleware to automagically handle AXFR and IXFR + // requests. + // let mut svc = XfrMiddlewareSvc::, _>::new(svc); + // svc.add_zone(zone.clone()); + // if let Some(diff) = diff { + // svc.add_diff(&zone, diff); + // } + + #[cfg(feature = "siphasher")] + let svc = CookiesMiddlewareSvc::, _>::with_random_secret(svc); + let svc = EdnsMiddlewareSvc::, _>::new(svc); + let svc = MandatoryMiddlewareSvc::, _>::new(svc); + let svc = Arc::new(svc); let sock = UdpSocket::bind(addr).await.unwrap(); let sock = Arc::new(sock); @@ -91,21 +114,32 @@ async fn main() { tokio::spawn(async move { tcp_srv.run().await }); + eprintln!("Ready"); + tokio::spawn(async move { loop { tokio::time::sleep(Duration::from_millis(5000)).await; - for (i, metrics) in udp_metrics.iter().enumerate() { - eprintln!( - "Server status: UDP[{i}]: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", - metrics.num_connections(), - metrics.num_inflight_requests(), - metrics.num_pending_writes(), - metrics.num_received_requests(), - metrics.num_sent_responses(), - ); + + let mut udp_num_connections = 0; + let mut udp_num_inflight_requests = 0; + let mut udp_num_pending_writes = 0; + let mut udp_num_received_requests = 0; + let mut udp_num_sent_responses = 0; + + for metrics in udp_metrics.iter() { + udp_num_connections += metrics.num_connections(); + udp_num_inflight_requests += metrics.num_inflight_requests(); + udp_num_pending_writes += metrics.num_pending_writes(); + udp_num_received_requests += metrics.num_received_requests(); + udp_num_sent_responses += metrics.num_sent_responses(); } eprintln!( - "Server status: TCP: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", + "Server status: #conn/#in-flight/#pending-writes/#msgs-recvd/#msgs-sent: UDP={}/{}/{}/{}/{} TCP={}/{}/{}/{}/{}", + udp_num_connections, + udp_num_inflight_requests, + udp_num_pending_writes, + udp_num_received_requests, + udp_num_sent_responses, tcp_metrics.num_connections(), tcp_metrics.num_inflight_requests(), tcp_metrics.num_pending_writes(), @@ -118,68 +152,11 @@ async fn main() { pending::<()>().await; } -enum SingleOrStream { - Single( - Once< - Pin< - Box< - dyn std::future::Future< - Output = Result< - CallResult>, - ServiceError, - >, - > + Send, - >, - >, - >, - ), - - Stream( - Box< - dyn futures::stream::Stream< - Item = Result>, ServiceError>, - > + Unpin - + Send, - >, - ), -} - -impl futures::stream::Stream for SingleOrStream { - type Item = Result>, ServiceError>; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - match self.deref_mut() { - SingleOrStream::Single(s) => s.poll_next_unpin(cx), - SingleOrStream::Stream(s) => s.poll_next_unpin(cx), - } - } -} - #[allow(clippy::type_complexity)] fn my_service( request: Request>, zones: Arc, -) -> SingleOrStream { - let qtype = request.message().sole_question().unwrap().qtype(); - match qtype { - Rtype::AXFR if request.transport_ctx().is_non_udp() => { - SingleOrStream::Stream(Box::new(handle_axfr_request( - request, zones, - ))) - } - _ => SingleOrStream::Single(once(Box::pin(handle_non_axfr_request( - request, zones, - )))), - } -} - -async fn handle_non_axfr_request( - request: Request>, - zones: Arc, -) -> Result>, ServiceError> { +) -> ServiceResult> { let question = request.message().sole_question().unwrap(); let zone = zones .find_zone(question.qname(), question.qclass()) @@ -197,193 +174,3 @@ async fn handle_non_axfr_request( let additional = answer.to_message(request.message(), builder); Ok(CallResult::new(additional)) } - -fn handle_axfr_request( - request: Request>, - zones: Arc, -) -> FuturesOrdered< - Pin< - Box< - dyn Future>, ServiceError>> - + Send, - >, - >, -> { - let mut stream = FuturesOrdered::< - Pin< - Box< - dyn Future>, ServiceError>> - + Send, - >, - >, - >::new(); - - // Look up the zone for the queried name. - let question = request.message().sole_question().unwrap(); - let zone = zones - .find_zone(question.qname(), question.qclass()) - .map(|zone| zone.read()); - - // If not found, return an NXDOMAIN error response. - let Some(zone) = zone else { - let answer = Answer::new(Rcode::NXDOMAIN); - add_to_stream(answer, request.message(), &mut stream); - return stream; - }; - - // https://datatracker.ietf.org/doc/html/rfc5936#section-2.2 - // 2.2: AXFR Response - // - // "An AXFR response that is transferring the zone's contents - // will consist of a series (which could be a series of - // length 1) of DNS messages. In such a series, the first - // message MUST begin with the SOA resource record of the - // zone, and the last message MUST conclude with the same SOA - // resource record. Intermediate messages MUST NOT contain - // the SOA resource record. The AXFR server MUST copy the - // Question section from the corresponding AXFR query message - // into the first response message's Question section. For - // subsequent messages, it MAY do the same or leave the - // Question section empty." - - // Get the SOA record as AXFR transfers must start and end with the SOA - // record. If not found, return a SERVFAIL error response. - let qname = question.qname().to_bytes(); - let Ok(soa_answer) = zone.query(qname, Rtype::SOA) else { - let answer = Answer::new(Rcode::SERVFAIL); - add_to_stream(answer, request.message(), &mut stream); - return stream; - }; - - // Push the begin SOA response message into the stream - add_to_stream(soa_answer.clone(), request.message(), &mut stream); - - // "The AXFR protocol treats the zone contents as an unordered - // collection (or to use the mathematical term, a "set") of - // RRs. Except for the requirement that the transfer must - // begin and end with the SOA RR, there is no requirement to - // send the RRs in any particular order or grouped into - // response messages in any particular way. Although servers - // typically do attempt to send related RRs (such as the RRs - // forming an RRset, and the RRsets of a name) as a - // contiguous group or, when message space allows, in the - // same response message, they are not required to do so, and - // clients MUST accept any ordering and grouping of the - // non-SOA RRs. Each RR SHOULD be transmitted only once, and - // AXFR clients MUST ignore any duplicate RRs received. - // - // Each AXFR response message SHOULD contain a sufficient - // number of RRs to reasonably amortize the per-message - // overhead, up to the largest number that will fit within a - // DNS message (taking the required content of the other - // sections into account, as described below). - // - // Some old AXFR clients expect each response message to - // contain only a single RR. To interoperate with such - // clients, the server MAY restrict response messages to a - // single RR. As there is no standard way to automatically - // detect such clients, this typically requires manual - // configuration at the server." - - let stream = Arc::new(Mutex::new(stream)); - let cloned_stream = stream.clone(); - let cloned_msg = request.message().clone(); - - let op = Box::new(move |owner: Dname<_>, rrset: &Rrset| { - if rrset.rtype() != Rtype::SOA { - let builder = mk_builder_for_target(); - let mut answer = - builder.start_answer(&cloned_msg, Rcode::NOERROR).unwrap(); - for item in rrset.data() { - answer.push((owner.clone(), rrset.ttl(), item)).unwrap(); - } - - let additional = answer.additional(); - let mut stream = cloned_stream.lock().unwrap(); - add_additional_to_stream(additional, &cloned_msg, &mut stream); - } - }); - zone.walk(op); - - let mutex = Arc::try_unwrap(stream).unwrap(); - let mut stream = mutex.into_inner().unwrap(); - - // Push the end SOA response message into the stream - add_to_stream(soa_answer, request.message(), &mut stream); - - stream -} - -#[allow(clippy::type_complexity)] -fn add_to_stream( - answer: Answer, - msg: &Message>, - stream: &mut FuturesOrdered< - Pin< - Box< - dyn Future>, ServiceError>> - + Send, - >, - >, - >, -) { - let builder = mk_builder_for_target(); - let additional = answer.to_message(msg, builder); - add_additional_to_stream(additional, msg, stream); -} - -#[allow(clippy::type_complexity)] -fn add_additional_to_stream( - mut additional: AdditionalBuilder>>, - msg: &Message>, - stream: &mut FuturesOrdered< - Pin< - Box< - dyn Future>, ServiceError>> - + Send, - >, - >, - >, -) { - set_axfr_header(msg, &mut additional); - stream.push_back(Box::pin(ready(Ok(CallResult::new(additional))))); -} - -fn set_axfr_header( - msg: &Message>, - additional: &mut AdditionalBuilder, -) where - Target: AsMut<[u8]>, - Target: OctetsBuilder, -{ - // https://datatracker.ietf.org/doc/html/rfc5936#section-2.2.1 - // 2.2.1: Header Values - // - // "These are the DNS message header values for AXFR responses. - // - // ID MUST be copied from request -- see Note a) - // - // QR MUST be 1 (Response) - // - // OPCODE MUST be 0 (Standard Query) - // - // Flags: - // AA normally 1 -- see Note b) - // TC MUST be 0 (Not truncated) - // RD RECOMMENDED: copy request's value; MAY be set to 0 - // RA SHOULD be 0 -- see Note c) - // Z "mbz" -- see Note d) - // AD "mbz" -- see Note d) - // CD "mbz" -- see Note d)" - let header = additional.header_mut(); - header.set_id(msg.header().id()); - header.set_qr(true); - header.set_opcode(Opcode::QUERY); - header.set_aa(true); - header.set_tc(false); - header.set_rd(msg.header().rd()); - header.set_ra(false); - header.set_z(false); - header.set_ad(false); - header.set_cd(false); -} diff --git a/examples/server-transports.rs b/examples/server-transports.rs index 7d62993b2..aec0c7e0d 100644 --- a/examples/server-transports.rs +++ b/examples/server-transports.rs @@ -1,20 +1,20 @@ -use core::future::ready; - use core::fmt; -use core::fmt::Debug; -use core::future::{Future, Ready}; -use core::ops::ControlFlow; +use core::future::{ready, Future, Ready}; use core::sync::atomic::{AtomicBool, AtomicU8, Ordering}; use core::task::{Context, Poll}; use core::time::Duration; + use std::fs::File; use std::io; use std::io::BufReader; use std::net::SocketAddr; use std::path::Path; +use std::pin::Pin; use std::sync::Arc; use std::sync::RwLock; +use futures::channel::mpsc::unbounded; +use futures::stream::{once, Empty, Once, Stream}; use octseq::{FreezeBuilder, Octets}; use rustls_pemfile::{certs, rsa_private_keys}; use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; @@ -28,25 +28,25 @@ use domain::base::iana::{Class, Rcode}; use domain::base::message_builder::{AdditionalBuilder, PushError}; use domain::base::name::ToLabelIter; use domain::base::wire::Composer; -use domain::base::{Dname, MessageBuilder, StreamTarget}; +use domain::base::{Dname, MessageBuilder, Rtype, Serial, StreamTarget, Ttl}; use domain::net::server::buf::VecBufSource; -use domain::net::server::dgram; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; -use domain::net::server::middleware::builder::MiddlewareBuilder; -use domain::net::server::middleware::processor::MiddlewareProcessor; #[cfg(feature = "siphasher")] -use domain::net::server::middleware::processors::cookies::CookiesMiddlewareProcessor; -use domain::net::server::middleware::processors::mandatory::MandatoryMiddlewareProcessor; +use domain::net::server::middleware::cookies::CookiesMiddlewareSvc; +use domain::net::server::middleware::edns::EdnsMiddlewareSvc; +use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; +use domain::net::server::middleware::stream::{ + MiddlewareStream, PostprocessingStream, +}; use domain::net::server::service::{ - CallResult, Service, ServiceError, ServiceFeedback, Transaction, + CallResult, Service, ServiceFeedback, ServiceResult, }; use domain::net::server::sock::AsyncAccept; -use domain::net::server::stream; use domain::net::server::stream::StreamServer; use domain::net::server::util::{mk_builder_for_target, service_fn}; -use domain::net::server::ConnectionConfig; -use domain::rdata::A; +use domain::rdata::{Soa, A}; +use std::vec::Vec; //----------- mk_answer() ---------------------------------------------------- @@ -70,29 +70,136 @@ where Ok(answer.additional()) } +fn mk_soa_answer( + msg: &Request>, + builder: MessageBuilder>, +) -> Result>, PushError> +where + Target: Octets + Composer + FreezeBuilder, + ::AppendError: fmt::Debug, +{ + let mname: Dname> = "a.root-servers.net".parse().unwrap(); + let rname = "nstld.verisign-grs.com".parse().unwrap(); + let mut answer = + builder.start_answer(msg.message(), Rcode::NOERROR).unwrap(); + answer.push(( + Dname::root_slice(), + 86390, + Soa::new( + mname, + rname, + Serial(2020081701), + Ttl::from_secs(1800), + Ttl::from_secs(900), + Ttl::from_secs(604800), + Ttl::from_secs(86400), + ), + ))?; + Ok(answer.additional()) +} + //----------- Example Service trait implementations -------------------------- -//--- MyService +//--- MySingleResultService -struct MyService; +struct MySingleResultService; /// This example shows how to implement the [`Service`] trait directly. /// +/// By implementing the trait directly you can do async calls with .await by +/// returning an async block, and can control the type of stream used and how +/// and when it gets populated. Neither are possible if implementing a service +/// via a simple compatible function signature or via service_fn, examples of +/// which can be seen below. +/// +/// For readability this example uses nonsensical future and stream types, +/// nonsensical because the future doesn't do any waiting and the stream +/// doesn't do any streaming. See the example below for a more complex case. +/// /// See [`query`] and [`name_to_ip`] for ways of implementing the [`Service`] /// trait for a function instead of a struct. -impl Service> for MyService { +impl Service> for MySingleResultService { type Target = Vec; - type Future = Ready, ServiceError>>; + type Stream = Once>>; + type Future = Ready; - fn call( - &self, - request: Request>, - ) -> Result, ServiceError> { + fn call(&self, request: Request>) -> Self::Future { let builder = mk_builder_for_target(); - let additional = mk_answer(&request, builder)?; - let item = ready(Ok(CallResult::new(additional))); - let txn = Transaction::single(item); - Ok(txn) + let additional = mk_answer(&request, builder).unwrap(); + let item = Ok(CallResult::new(additional)); + ready(once(ready(item))) + } +} + +//--- MyAsyncStreamingService + +struct MyAsyncStreamingService; + +/// This example also shows how to implement the [`Service`] trait directly. +/// +/// It implements a very simplistic dummy AXFR responder which can be tested +/// using `dig AXFR `. +/// +/// Unlike the simpler example above which returns a fixed type of future and +/// stream which are neither waiting nor streaming, this example goes to the +/// other extreme of returning future and stream types which are determined at +/// runtime (and thus involve Box'ing). +/// +/// There is a middle ground not shown here whereby you return concrete Future +/// and/or Stream implementations that actually wait and/or stream, e.g. +/// making the Stream type be UnboundedReceiver instead of Pin>. +impl Service> for MyAsyncStreamingService { + type Target = Vec; + type Stream = + Pin> + Send>>; + type Future = Pin + Send>>; + + fn call(&self, request: Request>) -> Self::Future { + Box::pin(async move { + if !matches!( + request + .message() + .sole_question() + .map(|q| q.qtype() == Rtype::AXFR), + Ok(true) + ) { + let builder = mk_builder_for_target(); + let additional = builder + .start_answer(request.message(), Rcode::NOTIMP) + .unwrap() + .additional(); + let item = Ok(CallResult::new(additional)); + let immediate_result = once(ready(item)); + return Box::pin(immediate_result) as Self::Stream; + } + + let (sender, receiver) = unbounded(); + let cloned_sender = sender.clone(); + + tokio::spawn(async move { + // Dummy AXFR response: SOA, record, SOA + tokio::time::sleep(Duration::from_millis(100)).await; + let builder = mk_builder_for_target(); + let additional = mk_soa_answer(&request, builder).unwrap(); + let item = Ok(CallResult::new(additional)); + cloned_sender.unbounded_send(item).unwrap(); + + tokio::time::sleep(Duration::from_millis(100)).await; + let builder = mk_builder_for_target(); + let additional = mk_answer(&request, builder).unwrap(); + let item = Ok(CallResult::new(additional)); + cloned_sender.unbounded_send(item).unwrap(); + + tokio::time::sleep(Duration::from_millis(100)).await; + let builder = mk_builder_for_target(); + let additional = mk_soa_answer(&request, builder).unwrap(); + let item = Ok(CallResult::new(additional)); + cloned_sender.unbounded_send(item).unwrap(); + }); + + Box::pin(receiver) as Self::Stream + }) } } @@ -104,20 +211,7 @@ impl Service> for MyService { /// The function signature is slightly more complex than when using /// [`service_fn`] (see the [`query`] example below). #[allow(clippy::type_complexity)] -fn name_to_ip( - request: Request>, -) -> Result< - Transaction< - Target, - impl Future, ServiceError>> + Send, - >, - ServiceError, -> -where - Target: - Composer + Octets + FreezeBuilder + Default + Send, - ::AppendError: Debug, -{ +fn name_to_ip(request: Request>) -> ServiceResult> { let mut out_answer = None; if let Ok(question) = request.message().sole_question() { let qname = question.qname(); @@ -153,8 +247,7 @@ where } let additional = out_answer.unwrap().additional(); - let item = Ok(CallResult::new(additional)); - Ok(Transaction::single(ready(item))) + Ok(CallResult::new(additional)) } //--- query() @@ -165,45 +258,28 @@ where /// The function signature is slightly simpler to write than when not using /// [`service_fn`] and supports passing in meta data without any extra /// boilerplate. -#[allow(clippy::type_complexity)] fn query( request: Request>, count: Arc, -) -> Result< - Transaction< - Vec, - impl Future>, ServiceError>> + Send, - >, - ServiceError, -> { +) -> ServiceResult> { let cnt = count .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |x| { Some(if x > 0 { x - 1 } else { 0 }) }) .unwrap(); - // This fn blocks the server until it returns. By returning a future that - // handles the request we allow the server to execute the future in the - // background without blocking the server. - let fut = async move { - eprintln!("Sleeping for 100ms"); - tokio::time::sleep(Duration::from_millis(100)).await; - - // Note: A real service would have application logic here to process - // the request and generate an response. - - let idle_timeout = Duration::from_millis((50 * cnt).into()); - let cmd = ServiceFeedback::Reconfigure { - idle_timeout: Some(idle_timeout), - }; - eprintln!("Setting idle timeout to {idle_timeout:?}"); + // Note: A real service would have application logic here to process + // the request and generate an response. - let builder = mk_builder_for_target(); - let answer = mk_answer(&request, builder)?; - let res = CallResult::new(answer).with_feedback(cmd); - Ok(res) + let idle_timeout = Duration::from_millis((50 * cnt).into()); + let cmd = ServiceFeedback::Reconfigure { + idle_timeout: Some(idle_timeout), }; - Ok(Transaction::single(fut)) + eprintln!("Setting idle timeout to {idle_timeout:?}"); + + let builder = mk_builder_for_target(); + let answer = mk_answer(&request, builder).unwrap(); + Ok(CallResult::new(answer).with_feedback(cmd)) } //----------- Example socket trait implementations --------------------------- @@ -355,9 +431,9 @@ impl AsyncAccept for RustlsTcpListener { //----------- CustomMiddleware ----------------------------------------------- #[derive(Default)] -struct Stats { - slowest_req: Duration, - fastest_req: Duration, +pub struct Stats { + slowest_req: Option, + fastest_req: Option, num_req_bytes: u32, num_resp_bytes: u32, num_reqs: u32, @@ -366,59 +442,41 @@ struct Stats { num_udp: u32, } -#[derive(Default)] -pub struct StatsMiddlewareProcessor { - stats: RwLock, -} - -impl StatsMiddlewareProcessor { - /// Creates an instance of this processor. - #[must_use] - pub fn new() -> Self { - Default::default() +impl std::fmt::Display for Stats { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "# Reqs={} [UDP={}, IPv4={}, IPv6={}] Bytes [rx={}, tx={}] Speed [fastest={}, slowest={}]", + self.num_reqs, + self.num_udp, + self.num_ipv4, + self.num_ipv6, + self.num_req_bytes, + self.num_resp_bytes, + self.fastest_req.map(|v| format!("{}μs", v.as_micros())).unwrap_or_else(|| "-".to_string()), + self.slowest_req.map(|v| format!("{}ms", v.as_millis())).unwrap_or_else(|| "-".to_string()), + ) } } -impl std::fmt::Display for StatsMiddlewareProcessor { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let stats = self.stats.read().unwrap(); - write!(f, "# Reqs={} [UDP={}, IPv4={}, IPv6={}] Bytes [rx={}, tx={}] Speed [fastest={}μs, slowest={}μs]", - stats.num_reqs, - stats.num_udp, - stats.num_ipv4, - stats.num_ipv6, - stats.num_req_bytes, - stats.num_resp_bytes, - stats.fastest_req.as_micros(), - stats.slowest_req.as_micros())?; - Ok(()) - } +pub struct StatsMiddlewareSvc { + svc: Svc, + stats: Arc>, } -impl MiddlewareProcessor - for StatsMiddlewareProcessor -where - RequestOctets: AsRef<[u8]> + Octets, - Target: Composer + Default, -{ - fn preprocess( - &self, - _request: &Request, - ) -> ControlFlow>> { - ControlFlow::Continue(()) +impl StatsMiddlewareSvc { + /// Creates an instance of this processor. + #[must_use] + pub fn new(svc: Svc, stats: Arc>) -> Self { + Self { svc, stats } } - fn postprocess( - &self, - request: &Request, - _response: &mut AdditionalBuilder>, - ) { - let duration = Instant::now().duration_since(request.received_at()); + fn preprocess(&self, request: &Request) + where + RequestOctets: Octets + Send + Sync + Unpin, + { let mut stats = self.stats.write().unwrap(); stats.num_reqs += 1; stats.num_req_bytes += request.message().as_slice().len() as u32; - stats.num_resp_bytes += _response.as_slice().len() as u32; if request.transport_ctx().is_udp() { stats.num_udp += 1; @@ -429,14 +487,100 @@ where } else { stats.num_ipv6 += 1; } + } + + fn postprocess( + request: &Request, + response: &AdditionalBuilder>, + stats: Arc>, + ) where + RequestOctets: Octets + Send + Sync + Unpin, + Svc: Service, + Svc::Target: AsRef<[u8]>, + { + let duration = Instant::now().duration_since(request.received_at()); + let mut stats = stats.write().unwrap(); + + stats.num_resp_bytes += response.as_slice().len() as u32; - if duration < stats.fastest_req { - stats.fastest_req = duration; + if duration < stats.fastest_req.unwrap_or(Duration::MAX) { + stats.fastest_req = Some(duration); } - if duration > stats.slowest_req { - stats.slowest_req = duration; + if duration > stats.slowest_req.unwrap_or(Duration::ZERO) { + stats.slowest_req = Some(duration); } } + + fn map_stream_item( + request: Request, + stream_item: ServiceResult, + stats: Arc>, + ) -> ServiceResult + where + RequestOctets: Octets + Send + Sync + Unpin, + Svc: Service, + Svc::Target: AsRef<[u8]>, + { + if let Ok(cr) = &stream_item { + if let Some(response) = cr.response() { + Self::postprocess(&request, response, stats); + } + } + stream_item + } +} + +impl Service for StatsMiddlewareSvc +where + RequestOctets: Octets + Send + Sync + 'static + Unpin, + Svc: Service, + Svc::Target: AsRef<[u8]>, + Svc::Future: Unpin, +{ + type Target = Svc::Target; + type Stream = MiddlewareStream< + Svc::Stream, + PostprocessingStream< + RequestOctets, + Svc::Future, + Svc::Stream, + Arc>, + >, + Empty>, + ServiceResult, + >; + type Future = Ready; + + fn call(&self, request: Request) -> Self::Future { + self.preprocess(&request); + let svc_call_fut = self.svc.call(request.clone()); + let map = PostprocessingStream::new( + svc_call_fut, + request, + self.stats.clone(), + Self::map_stream_item, + ); + ready(MiddlewareStream::Map(map)) + } +} + +//------------ build_middleware_chain() -------------------------------------- + +#[allow(clippy::type_complexity)] +fn build_middleware_chain( + svc: Svc, + stats: Arc>, +) -> StatsMiddlewareSvc< + MandatoryMiddlewareSvc< + Vec, + EdnsMiddlewareSvc, CookiesMiddlewareSvc, Svc>>, + >, +> { + #[cfg(feature = "siphasher")] + let svc = CookiesMiddlewareSvc::, _>::with_random_secret(svc); + let svc = EdnsMiddlewareSvc::, _>::new(svc); + let svc = MandatoryMiddlewareSvc::, _>::new(svc); + StatsMiddlewareSvc::new(svc, stats.clone()) } //----------- main() --------------------------------------------------------- @@ -447,8 +591,8 @@ async fn main() { eprintln!(" dig +short -4 @127.0.0.1 -p 8053 A 1.2.3.4"); eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8053 A google.com"); eprintln!(" dig +short -4 @127.0.0.1 -p 8054 A google.com"); - eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8080 A google.com"); - eprintln!(" dig +short -6 @::1 +tcp -p 8080 A google.com"); + eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8080 AXFR google.com"); + eprintln!(" dig +short -6 @::1 +tcp -p 8080 AXFR google.com"); eprintln!(" dig +short -4 @127.0.0.1 +tcp -p 8081 A google.com"); eprintln!(" dig +short -4 @127.0.0.1 +tls -p 8443 A google.com"); @@ -463,33 +607,66 @@ async fn main() { .ok(); // ----------------------------------------------------------------------- - // Wrap `MyService` in an `Arc` so that it can be used by multiple servers - // at once. - let svc = Arc::new(MyService); + // Inject a custom statistics middleware service (defined above) at the + // start of each middleware chain constructed below so that it can time + // the request processing time from as early till as late as possible + // (excluding time spent in the servers that receive the requests and send + // the responses). Each chain needs its own copy of the stats middleware + // but they can share a single set of statistic counters. + let stats = Arc::new(RwLock::new(Stats::default())); // ----------------------------------------------------------------------- - // Prepare a modern middleware chain for use by servers defined below. - // Inject a custom statistics middleware processor (defined above) at the - // start of the chain so that it can time the request processing time from - // as early till as late as possible (excluding time spent in the servers - // that receive the requests and send the responses). - let mut middleware = MiddlewareBuilder::default(); - let stats = Arc::new(StatsMiddlewareProcessor::new()); - middleware.push_front(stats.clone()); - let middleware = middleware.build(); + // Create services with accompanying middleware chains to answer incoming + // requests. + + // 1. MySingleResultService: a struct that implements the `Service` trait + // directly. + let my_svc = Arc::new(build_middleware_chain( + MySingleResultService, + stats.clone(), + )); + + // 2. MyAsyncStreamingService: another struct that implements the + // `Service` trait directly. + let my_async_svc = Arc::new(build_middleware_chain( + MyAsyncStreamingService, + stats.clone(), + )); + + // 2. name_to_ip: a service impl defined as a function compatible with the + // `Service` trait. + let name_into_ip_svc = + Arc::new(build_middleware_chain(name_to_ip, stats.clone())); + + // 3. query: a service impl defined as a function converted to a `Service` + // impl via the `service_fn()` helper function. + // Show that we don't have to use the same middleware with every server by + // creating a separate middleware chain for use just by this server. + let count = Arc::new(AtomicU8::new(5)); + let svc = service_fn(query, count); + let svc = MandatoryMiddlewareSvc::, _>::new(svc); + #[cfg(feature = "siphasher")] + let svc = { + let server_secret = "server12secret34".as_bytes().try_into().unwrap(); + CookiesMiddlewareSvc::, _>::new(svc, server_secret) + }; + let svc = StatsMiddlewareSvc::new(svc, stats.clone()); + let query_svc = Arc::new(svc); // ----------------------------------------------------------------------- - // Run a DNS server on UDP port 8053 on 127.0.0.1. Test it like so: + // Run a DNS server on UDP port 8053 on 127.0.0.1 using the name_to_ip + // service defined above and accompanying middleware. Test it like so: // dig +short -4 @127.0.0.1 -p 8053 A google.com + let udpsocket = UdpSocket::bind("127.0.0.1:8053").await.unwrap(); let buf = Arc::new(VecBufSource); - let mut config = dgram::Config::default(); - config.set_middleware_chain(middleware.clone()); - let srv = - DgramServer::with_config(udpsocket, buf.clone(), name_to_ip, config); - + let srv = DgramServer::new(udpsocket, buf.clone(), name_into_ip_svc); let udp_join_handle = tokio::spawn(async move { srv.run().await }); + // ----------------------------------------------------------------------- + // Create an instance of our MyService `Service` impl with accompanying + // middleware. + // ----------------------------------------------------------------------- // Run a DNS server on TCP port 8053 on 127.0.0.1. Test it like so: // dig +short +keepopen +tcp -4 @127.0.0.1 -p 8053 A google.com @@ -498,16 +675,7 @@ async fn main() { v4socket.bind("127.0.0.1:8053".parse().unwrap()).unwrap(); let v4listener = v4socket.listen(1024).unwrap(); let buf = Arc::new(VecBufSource); - let mut conn_config = ConnectionConfig::default(); - conn_config.set_middleware_chain(middleware.clone()); - let mut config = stream::Config::default(); - config.set_connection_config(conn_config); - let srv = StreamServer::with_config( - v4listener, - buf.clone(), - svc.clone(), - config, - ); + let srv = StreamServer::new(v4listener, buf.clone(), query_svc.clone()); let srv = srv.with_pre_connect_hook(|stream| { // Demonstrate one way without having access to the code that creates // the socket initially to enable TCP keep alive, @@ -531,42 +699,41 @@ async fn main() { let tcp_join_handle = tokio::spawn(async move { srv.run().await }); + // ----------------------------------------------------------------------- + // This UDP example sets IP_MTU_DISCOVER via setsockopt(), using the libc + // crate (as the nix crate doesn't support IP_MTU_DISCOVER at the time of + // writing). This example is inspired by: + // + // - https://www.ietf.org/archive/id/draft-ietf-dnsop-avoid-fragmentation-17.html#name-recommendations-for-udp-res + // - https://mailarchive.ietf.org/arch/msg/dnsop/Zy3wbhHephubsy2uJesGeDst4F4/ + // - https://man7.org/linux/man-pages/man7/ip.7.html + // + // Some other good reading on sending faster via UDP with Rust: + // - https://devork.be/blog/2023/11/modern-linux-sockets/ + // + // We could also try the following settings that the Unbound man page + // mentions: + // - SO_RCVBUF - Unbound advises setting so-rcvbuf to 4m on busy + // servers to prevent short request spikes causing + // packet drops, + // - SO_SNDBUF - Unbound advises setting so-sndbuf to 4m on busy + // servers to avoid resource temporarily unavailable + // errors, + // - SO_REUSEPORT - Unbound advises to turn it off at extreme load to + // distribute queries evenly, + // - IP_TRANSPARENT - Allows to bind to non-existent IP addresses that + // are going to exist later on. Unbound uses + // IP_BINDANY on FreeBSD and SO_BINDANY on OpenBSD. + // - IP_FREEBIND - Linux only, similar to IP_TRANSPARENT. Allows to + // bind to IP addresses that are nonlocal or do not + // exist, like when the network interface is down. + // - TCP_MAXSEG - Value lower than common MSS on Ethernet (1220 for + // example) will address path MTU problem. + // - A means to control the value of the Differentiated Services + // Codepoint (DSCP) in the differentiated services field (DS) of the + // outgoing IP packet headers. #[cfg(target_os = "linux")] let udp_mtu_join_handle = { - // This UDP example sets IP_MTU_DISCOVER via setsockopt(), using the - // libc crate (as the nix crate doesn't support IP_MTU_DISCOVER at the - // time of writing). This example is inspired by: - // - // - https://www.ietf.org/archive/id/draft-ietf-dnsop-avoid-fragmentation-17.html#name-recommendations-for-udp-res - // - https://mailarchive.ietf.org/arch/msg/dnsop/Zy3wbhHephubsy2uJesGeDst4F4/ - // - https://man7.org/linux/man-pages/man7/ip.7.html - // - // Some other good reading on sending faster via UDP with Rust: - // - https://devork.be/blog/2023/11/modern-linux-sockets/ - // - // We could also try the following settings that the Unbound man page - // mentions: - // - SO_RCVBUF - Unbound advises setting so-rcvbuf to 4m on busy - // servers to prevent short request spikes causing - // packet drops, - // - SO_SNDBUF - Unbound advises setting so-sndbuf to 4m on busy - // servers to avoid resource temporarily - // unavailable errors, - // - SO_REUSEPORT - Unbound advises to turn it off at extreme load - // to distribute queries evenly, - // - IP_TRANSPARENT - Allows to bind to non-existent IP addresses - // that are going to exist later on. Unbound uses - // IP_BINDANY on FreeBSD and SO_BINDANY on - // OpenBSD. - // - IP_FREEBIND - Linux only, similar to IP_TRANSPARENT. Allows - // to bind to IP addresses that are nonlocal or do - // not exist, like when the network interface is - // down. - // - TCP_MAXSEG - Value lower than common MSS on Ethernet (1220 - // for example) will address path MTU problem. - // - A means to control the value of the Differentiated Services - // Codepoint (DSCP) in the differentiated services field (DS) of - // the outgoing IP packet headers. fn setsockopt(socket: libc::c_int, flag: libc::c_int) -> libc::c_int { unsafe { libc::setsockopt( @@ -595,14 +762,7 @@ async fn main() { } } - let mut config = dgram::Config::default(); - config.set_middleware_chain(middleware.clone()); - let srv = DgramServer::with_config( - udpsocket, - buf.clone(), - svc.clone(), - config, - ); + let srv = DgramServer::new(udpsocket, buf.clone(), my_svc.clone()); tokio::spawn(async move { srv.run().await }) }; @@ -623,38 +783,28 @@ async fn main() { let v6listener = v6socket.listen(1024).unwrap(); let listener = DoubleListener::new(v4listener, v6listener); - let mut conn_config = ConnectionConfig::new(); - conn_config.set_middleware_chain(middleware.clone()); - let mut config = stream::Config::new(); - config.set_connection_config(conn_config); - let srv = - StreamServer::with_config(listener, buf.clone(), svc.clone(), config); + let srv = StreamServer::new(listener, buf.clone(), my_async_svc); let double_tcp_join_handle = tokio::spawn(async move { srv.run().await }); // ----------------------------------------------------------------------- - // Demonstrate listening with TCP Fast Open enabled (via the tokio-tfo crate). - // On Linux strace can be used to show that the socket options are indeed - // set as expected, e.g.: + // Demonstrate listening with TCP Fast Open enabled (via the tokio-tfo + // crate). On Linux strace can be used to show that the socket options are + // indeed set as expected, e.g.: // // > strace -e trace=setsockopt cargo run --example serve \ // --features serve,tokio-tfo --release // Finished release [optimized] target(s) in 0.12s // Running `target/release/examples/serve` - // setsockopt(6, SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0 - // setsockopt(7, SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0 - // setsockopt(8, SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0 - // setsockopt(8, SOL_TCP, TCP_FASTOPEN, [1024], 4) = 0 + // setsockopt(6, SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0 setsockopt(7, + // SOL_SOCKET, SO_REUSEADDR, [1], 4) = 0 setsockopt(8, SOL_SOCKET, + // SO_REUSEADDR, [1], 4) = 0 setsockopt(8, SOL_TCP, TCP_FASTOPEN, + // [1024], 4) = 0 let listener = TfoListener::bind("127.0.0.1:8081".parse().unwrap()) .await .unwrap(); let listener = LocalTfoListener(listener); - let mut conn_config = ConnectionConfig::new(); - conn_config.set_middleware_chain(middleware.clone()); - let mut config = stream::Config::new(); - config.set_connection_config(conn_config); - let srv = - StreamServer::with_config(listener, buf.clone(), svc.clone(), config); + let srv = StreamServer::new(listener, buf.clone(), my_svc.clone()); let tfo_join_handle = tokio::spawn(async move { srv.run().await }); // ----------------------------------------------------------------------- @@ -677,34 +827,7 @@ async fn main() { let listener = TcpListener::bind("127.0.0.1:8082").await.unwrap(); let listener = BufferedTcpListener(listener); - let count = Arc::new(AtomicU8::new(5)); - - // Make our service from the `query` function with the help of the - // `service_fn` function. - let fn_svc = service_fn(query, count); - - // Show that we don't have to use the same middleware with every server by - // creating a separate middleware chain for use just by this server, and - // also show that by creating the individual middleware processors - // ourselves we can override their default configuration. - let mut fn_svc_middleware = MiddlewareBuilder::new(); - fn_svc_middleware.push(MandatoryMiddlewareProcessor::new().into()); - - #[cfg(feature = "siphasher")] - { - let server_secret = "server12secret34".as_bytes().try_into().unwrap(); - fn_svc_middleware - .push(CookiesMiddlewareProcessor::new(server_secret).into()); - } - - let fn_svc_middleware = fn_svc_middleware.build(); - - let mut conn_config = ConnectionConfig::new(); - conn_config.set_middleware_chain(fn_svc_middleware); - let mut config = stream::Config::new(); - config.set_connection_config(conn_config); - let srv = - StreamServer::with_config(listener, buf.clone(), fn_svc, config); + let srv = StreamServer::new(listener, buf.clone(), query_svc); let fn_join_handle = tokio::spawn(async move { srv.run().await }); // ----------------------------------------------------------------------- @@ -739,23 +862,17 @@ async fn main() { let acceptor = TlsAcceptor::from(Arc::new(config)); let listener = TcpListener::bind("127.0.0.1:8443").await.unwrap(); let listener = RustlsTcpListener::new(listener, acceptor); - - let mut conn_config = ConnectionConfig::new(); - conn_config.set_middleware_chain(middleware.clone()); - let mut config = stream::Config::new(); - config.set_connection_config(conn_config); - let srv = - StreamServer::with_config(listener, buf.clone(), svc.clone(), config); + let srv = StreamServer::new(listener, buf.clone(), my_svc.clone()); let tls_join_handle = tokio::spawn(async move { srv.run().await }); // ----------------------------------------------------------------------- // Print statistics periodically tokio::spawn(async move { - let mut interval = tokio::time::interval(Duration::from_secs(15)); + let mut interval = tokio::time::interval(Duration::from_secs(5)); loop { interval.tick().await; - println!("Statistics report: {stats}"); + println!("Statistics report: {}", stats.read().unwrap()); } }); diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index 4f10eeaa3..bab15a30a 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -23,15 +23,15 @@ use crate::base::{Message, StreamTarget}; use crate::net::server::buf::BufSource; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -use crate::net::server::service::{ - CallResult, Service, ServiceError, ServiceFeedback, -}; +use crate::net::server::service::{Service, ServiceError, ServiceFeedback}; use crate::net::server::util::to_pcap_text; use crate::utils::config::DefMinMax; use super::message::{NonUdpTransportContext, TransportSpecificContext}; use super::stream::Config as ServerConfig; use super::ServerCommand; +use crate::base::message_builder::AdditionalBuilder; +use arc_swap::ArcSwap; use std::fmt::Display; /// Limit on the amount of time to allow between client requests. @@ -87,6 +87,7 @@ const MAX_QUEUED_RESPONSES: DefMinMax = DefMinMax::new(10, 0, 1024); //----------- Config --------------------------------------------------------- /// Configuration for a stream server connection. +#[derive(Copy, Debug)] pub struct Config { /// Limit on the amount of time to allow between client requests. /// @@ -114,9 +115,6 @@ pub struct Config { /// Limit on the number of DNS responses queued for wriing to the client. max_queued_responses: usize, - // /// The middleware chain used to pre-process requests and post-process - // /// responses. - // middleware_chain: MiddlewareChain, } impl Config { @@ -200,25 +198,6 @@ impl Config { pub fn set_max_queued_responses(&mut self, value: usize) { self.max_queued_responses = value; } - - // /// Set the middleware chain used to pre-process requests and post-process - // /// responses. - // /// - // /// # Reconfigure - // /// - // /// On [`StreamServer::reconfigure`] only new connections created after - // /// this setting is changed will use the new value, existing connections - // /// and in-flight requests (and their responses) will continue to use - // /// their current middleware chain. - // /// - // /// [`StreamServer::reconfigure`]: - // /// super::stream::StreamServer::reconfigure() - // pub fn set_middleware_chain( - // &mut self, - // value: MiddlewareChain, - // ) { - // self.middleware_chain = value; - // } } //--- Default @@ -237,11 +216,7 @@ impl Default for Config { impl Clone for Config { fn clone(&self) -> Self { - Self { - idle_timeout: self.idle_timeout, - response_write_timeout: self.response_write_timeout, - max_queued_responses: self.max_queued_responses, - } + *self } } @@ -251,7 +226,8 @@ impl Clone for Config { pub struct Connection where Buf: BufSource, - Svc: Service, + Buf::Output: Send + Sync + Unpin, + Svc: Service + Clone, { /// Flag used by the Drop impl to track if the metric count has to be /// decreased or not. @@ -265,7 +241,7 @@ where /// /// Note: Some reconfiguration is possible at runtime via /// [`ServerCommand::Reconfigure`] and [`ServiceFeedback::Reconfigure`]. - config: Config, + config: Arc>, /// The address of the connected client. addr: SocketAddr, @@ -280,11 +256,11 @@ where /// The reader for consuming from the queue of responses waiting to be /// written back to the client. - result_q_rx: mpsc::Receiver>, + result_q_rx: mpsc::Receiver>>, /// The writer for pushing ready responses onto the queue waiting /// to be written back the client. - result_q_tx: mpsc::Sender>, + result_q_tx: mpsc::Sender>>, /// A [`Service`] for handling received requests and generating responses. service: Svc, @@ -302,9 +278,8 @@ impl Connection where Stream: AsyncRead + AsyncWrite, Buf: BufSource, - Buf::Output: Octets, - Svc: Service, - Svc::Target: Composer + Default, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone, { /// Creates a new handler for an accepted stream connection. #[must_use] @@ -339,6 +314,7 @@ where let (stream_rx, stream_tx) = tokio::io::split(stream); let (result_q_tx, result_q_rx) = mpsc::channel(config.max_queued_responses); + let config = Arc::new(ArcSwap::from_pointee(config)); let idle_timer = IdleTimer::new(); // Place the ReadHalf of the stream into an Option so that we can take @@ -372,13 +348,10 @@ impl Connection where Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, Buf: BufSource + Send + Sync + Clone + 'static, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + 'static, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Send + Composer + Default, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Target: Composer + Send, + Svc::Stream: Send, { /// Start reading requests and writing responses to the stream. /// @@ -395,7 +368,7 @@ where mut self, command_rx: watch::Receiver>, ) where - Svc::Stream: Send, + Svc::Future: Send, { self.metrics.inc_num_connections(); @@ -412,13 +385,11 @@ impl Connection where Stream: AsyncRead + AsyncWrite + Send + Sync + 'static, Buf: BufSource + Send + Sync + Clone + 'static, - Buf::Output: Octets + Send + Sync, - Svc: Service + Send + Sync + 'static, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Send + Composer + Default, + Buf::Output: Octets + Send + Sync + Unpin, + Svc: Service + Clone + Send + Sync + 'static, + Svc::Target: Composer + Send, + Svc::Future: Send, + Svc::Stream: Send, { /// Connection handler main loop. async fn run_until_error( @@ -455,7 +426,7 @@ where self.process_queued_result(res).await } - _ = sleep_until(self.idle_timer.idle_timeout_deadline(self.config.idle_timeout)) => { + _ = sleep_until(self.idle_timer.idle_timeout_deadline(self.config.load().idle_timeout)) => { self.process_dns_idle_timeout() } @@ -534,12 +505,7 @@ where } ServerCommand::Reconfigure(ServerConfig { - connection_config: - Config { - idle_timeout, - response_write_timeout, - max_queued_responses: _, - }, + connection_config, .. // Ignore the Server specific configuration settings }) => { // Support RFC 7828 "The edns-tcp-keepalive EDNS0 Option". @@ -554,9 +520,7 @@ where // mechanism to signal to us that we should adjust the point // at which we will consider the connectin to be idle and thus // potentially worthy of timing out. - debug!("Server connection timeout reconfigured to {idle_timeout:?}"); - self.config.idle_timeout = *idle_timeout; - self.config.response_write_timeout = *response_write_timeout; + self.config.store(Arc::new(*connection_config)); } ServerCommand::Shutdown => { @@ -576,7 +540,10 @@ where } /// Stop queueing new responses and process those already in the queue. - async fn flush_write_queue(&mut self) { + async fn flush_write_queue(&mut self) + // where + // Target: Composer, + { debug!("Flushing connection write queue."); // Stop accepting new response messages (should we check for in-flight // messages that haven't generated a response yet but should be @@ -585,10 +552,9 @@ where trace!("Stop queueing up new results."); self.result_q_rx.close(); trace!("Process already queued results."); - while let Some(call_result) = self.result_q_rx.recv().await { + while let Some(response) = self.result_q_rx.recv().await { trace!("Processing queued result."); - if let Err(err) = - self.process_queued_result(Some(call_result)).await + if let Err(err) = self.process_queued_result(Some(response)).await { warn!("Error while processing queued result: {err}"); } else { @@ -601,27 +567,27 @@ where /// Process a single queued response. async fn process_queued_result( &mut self, - call_result: Option>, - ) -> Result<(), ConnectionEvent> { + response: Option>>, + ) -> Result<(), ConnectionEvent> +// where + // Target: Composer, + { // If we failed to read the results of requests processed by the // service because the queue holding those results is empty and can no // longer be read from, then there is no point continuing to read from // the input stream because we will not be able to access the result // of processing the request. I'm not sure when this could happen, // perhaps if we were dropped? - let Some(call_result) = call_result else { + let Some(response) = response else { + trace!("Disconnecting due to failed response queue read."); return Err(ConnectionEvent::DisconnectWithFlush); }; - let (response, feedback) = call_result.into_inner(); - - if let Some(feedback) = feedback { - self.process_service_feedback(feedback).await; - } - - if let Some(response) = response { - self.write_response_to_stream(response.finish()).await; - } + trace!( + "Writing queued response with id {} to stream", + response.header().id() + ); + self.write_response_to_stream(response.finish()).await; Ok(()) } @@ -630,7 +596,10 @@ where async fn write_response_to_stream( &mut self, msg: StreamTarget, - ) { + ) + // where + // Target: AsRef<[u8]>, + { if enabled!(Level::TRACE) { let bytes = msg.as_dgram_slice(); let pcap_text = to_pcap_text(bytes, bytes.len()); @@ -638,7 +607,7 @@ where } match timeout( - self.config.response_write_timeout, + self.config.load().response_write_timeout, self.stream_tx.write_all(msg.as_stream_slice()), ) .await @@ -646,7 +615,7 @@ where Err(_) => { error!( "Write timed out (>{:?})", - self.config.response_write_timeout + self.config.load().response_write_timeout ); // TODO: Push it to the back of the queue to retry it? } @@ -665,20 +634,6 @@ where } } - /// Decide what to do with received [`ServiceFeedback`]. - async fn process_service_feedback(&mut self, cmd: ServiceFeedback) { - match cmd { - ServiceFeedback::Reconfigure { idle_timeout } => { - if let Some(idle_timeout) = idle_timeout { - debug!( - "Reconfigured connection timeout to {idle_timeout:?}" - ); - self.config.idle_timeout = idle_timeout; - } - } - } - } - /// Implemnt DNS rules regarding timing out of idle connections. /// /// Disconnects the current connection of the timer is expired, flushing @@ -687,7 +642,7 @@ where // DNS idle timeout elapsed, or was it reset? if self .idle_timer - .idle_timeout_expired(self.config.idle_timeout) + .idle_timeout_expired(self.config.load().idle_timeout) { Err(ConnectionEvent::DisconnectWithoutFlush) } else { @@ -701,11 +656,7 @@ where res: Result, ) -> Result<(), ConnectionEvent> where - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin - + Send, + Svc::Stream: Send, { if let Ok(buf) = res { let received_at = Instant::now(); @@ -732,32 +683,101 @@ where Ok(msg) => { let ctx = NonUdpTransportContext::new(Some( - self.config.idle_timeout, + self.config.load().idle_timeout, )); let ctx = TransportSpecificContext::NonUdp(ctx); let request = Request::new(self.addr, received_at, msg, ctx); - let mut stream = self.service.call(request); - while let Some(Ok(call_result)) = stream.next().await { - match self.result_q_tx.try_send(call_result) { - Ok(()) => { - self.metrics.set_num_pending_writes( - self.result_q_tx.max_capacity() - - self.result_q_tx.capacity(), - ); - } - Err(TrySendError::Closed(_msg)) => { - // TODO: How should we properly communicate this to the operator? - error!("Unable to queue message for sending: server is shutting down."); + let svc = self.service.clone(); + let result_q_tx = self.result_q_tx.clone(); + let metrics = self.metrics.clone(); + let config = self.config.clone(); + + trace!( + "Spawning task to handle new message with id {}", + request.message().header().id() + ); + tokio::spawn(async move { + let request_id = request.message().header().id(); + trace!("Calling service for request id {request_id}"); + let mut stream = svc.call(request).await; + let mut in_transaction = false; + + trace!("Awaiting service call results for request id {request_id}"); + while let Some(Ok(call_result)) = stream.next().await + { + trace!("Processing service call result for request id {request_id}"); + let (response, feedback) = + call_result.into_inner(); + + if let Some(feedback) = feedback { + match feedback { + ServiceFeedback::Reconfigure { + idle_timeout, + } => { + if let Some(idle_timeout) = + idle_timeout + { + debug!( + "Reconfigured connection timeout to {idle_timeout:?}" + ); + let guard = config.load(); + let mut new_config = **guard; + new_config.idle_timeout = + idle_timeout; + config + .store(Arc::new(new_config)); + } + } + + ServiceFeedback::BeginTransaction => { + in_transaction = true; + } + + ServiceFeedback::EndTransaction => { + in_transaction = false; + } + } } - Err(TrySendError::Full(_msg)) => { - // TODO: How should we properly communicate this to the operator? - error!("Unable to queue message for sending: queue is full."); + if let Some(mut response) = response { + loop { + match result_q_tx.try_send(response) { + Ok(()) => { + trace!("Queued message for sending: # pending writes={}", result_q_tx.max_capacity() + - result_q_tx.capacity()); + metrics.set_num_pending_writes( + result_q_tx.max_capacity() + - result_q_tx.capacity(), + ); + break; + } + + Err(TrySendError::Closed(_)) => { + error!("Unable to queue message for sending: server is shutting down."); + break; + } + + Err(TrySendError::Full( + unused_response, + )) => { + if in_transaction { + // Wait until there is space in the message queue. + tokio::task::yield_now() + .await; + response = unused_response; + } else { + error!("Unable to queue message for sending: queue is full."); + break; + } + } + } + } } } - } + trace!("Finished processing service call results for request id {request_id}"); + }); } } } @@ -771,7 +791,8 @@ where impl Drop for Connection where Buf: BufSource, - Svc: Service, + Buf::Output: Send + Sync + Unpin, + Svc: Service + Clone, { fn drop(&mut self) { if self.active { diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index 8ab6a7e9f..dfd6961a5 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -20,7 +20,6 @@ use std::string::String; use std::string::ToString; use std::sync::{Arc, Mutex}; -use futures::StreamExt; use octseq::Octets; use tokio::io::ReadBuf; use tokio::net::UdpSocket; @@ -35,22 +34,19 @@ use tracing::{enabled, error, trace}; use crate::base::Message; use crate::net::server::buf::BufSource; use crate::net::server::error::Error; -// use crate::net::server::message::CommonMessageFlow; use crate::net::server::message::Request; use crate::net::server::metrics::ServerMetrics; -// use crate::net::server::middleware::chain::MiddlewareChain; -use crate::net::server::service::{CallResult, Service, ServiceFeedback}; +use crate::net::server::service::{Service, ServiceFeedback}; use crate::net::server::sock::AsyncDgramSock; use crate::net::server::util::to_pcap_text; use crate::utils::config::DefMinMax; use super::buf::VecBufSource; use super::message::{TransportSpecificContext, UdpTransportContext}; -use super::service::ServiceError; -// use super::middleware::builder::MiddlewareBuilder; use super::ServerCommand; use crate::base::wire::Composer; use arc_swap::ArcSwap; +use futures::prelude::stream::StreamExt; /// A UDP transport based DNS server transport. /// @@ -86,21 +82,14 @@ const MAX_RESPONSE_SIZE: DefMinMax = DefMinMax::new(1232, 512, 4096); /// Configuration for a datagram server. #[derive(Debug)] -pub struct Config /**/ { +pub struct Config { /// Limit suggested to [`Service`] on maximum response size to create. max_response_size: Option, /// Limit the time to wait for a complete message to be written to the client. write_timeout: Duration, - // /// The middleware chain used to pre-process requests and post-process - // /// responses. - // middleware_chain: MiddlewareChain, } -// impl Config -// where -// RequestOctets: Octets, -// Target: Composer + Default, impl Config { /// Creates a new, default config. pub fn new() -> Self { @@ -145,51 +134,26 @@ impl Config { pub fn set_write_timeout(&mut self, value: Duration) { self.write_timeout = value; } - - // /// Set the middleware chain used to pre-process requests and post-process - // /// responses. - // /// - // /// # Reconfigure - // /// - // /// On [`DgramServer::reconfigure`]` any change to this setting will only - // /// affect requests (and their responses) received after the setting is - // /// changed, in progress requests will be unaffected. - // pub fn set_middleware_chain( - // &mut self, - // value: MiddlewareChain, - // ) { - // self.middleware_chain = value; - // } } //--- Default -// impl Default for Config -// where -// RequestOctets: Octets, -// Target: Composer + Default, impl Default for Config { fn default() -> Self { Self { max_response_size: Some(MAX_RESPONSE_SIZE.default()), write_timeout: WRITE_TIMEOUT.default(), - // middleware_chain: MiddlewareBuilder::default().build(), } } } //--- Clone -// impl Clone for Config -// where -// RequestOctets: Octets, -// Target: Composer + Default, impl Clone for Config { fn clone(&self) -> Self { Self { max_response_size: self.max_response_size, write_timeout: self.write_timeout, - // middleware_chain: self.middleware_chain.clone(), } } } @@ -197,16 +161,12 @@ impl Clone for Config { //------------ DgramServer --------------------------------------------------- /// A [`ServerCommand`] capable of propagating a DgramServer [`Config`] value. -// type ServerCommandType = ServerCommand>; type ServerCommandType = ServerCommand; /// A thread safe sender of [`ServerCommand`]s. -// type CommandSender = -// Arc>>>; type CommandSender = Arc>>; /// A thread safe receiver of [`ServerCommand`]s. -// type CommandReceiver = watch::Receiver>; type CommandReceiver = watch::Receiver; /// A server for connecting clients via a datagram based network transport to @@ -300,27 +260,25 @@ pub struct DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + 'static + Clone, - Svc::Target: Send + Composer + Default, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, + Svc::Target: Composer + Send, + Svc::Stream: Send, + Svc::Future: Send, { /// The configuration of the server. - config: Arc*/>>, + config: Arc>, /// A receiver for receiving [`ServerCommand`]s. /// /// Used by both the server and spawned connections to react to sent /// commands. - command_rx: CommandReceiver, //, + command_rx: CommandReceiver, /// A sender for sending [`ServerCommand`]s. /// /// Used to signal the server to stop, reconfigure, etc. - command_tx: CommandSender, //, + command_tx: CommandSender, /// The network socket over which client requests will be received /// and responses sent. @@ -342,13 +300,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync, Buf: BufSource + Send + Sync, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Target: Send + Composer + Default, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, + Svc::Target: Composer + Send, + Svc::Stream: Send, + Svc::Future: Send, { /// Constructs a new [`DgramServer`] with default configuration. /// @@ -376,7 +332,7 @@ where sock: Sock, buf: Buf, service: Svc, - config: Config, //, + config: Config, ) -> Self { let (command_tx, command_rx) = watch::channel(ServerCommand::Init); let command_tx = Arc::new(Mutex::new(command_tx)); @@ -401,13 +357,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync + 'static + Debug, + Buf::Output: Octets + Send + Sync + 'static + Debug + Unpin, Svc: Service + Send + Sync + 'static + Clone, - Svc::Target: Send + Composer + Debug + Default, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, + Svc::Target: Composer + Send, + Svc::Stream: Send, + Svc::Future: Send, { /// Get a reference to the network source being used to receive messages. #[must_use] @@ -428,13 +382,11 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync, - Buf::Output: Octets + Send + Sync + 'static, + Buf::Output: Octets + Send + Sync + 'static + Unpin, Svc: Service + Send + Sync + 'static + Clone, - Svc::Target: Send + Composer + Default, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, + Svc::Target: Composer + Send, + Svc::Stream: Send, + Svc::Future: Send, { /// Start the server. /// @@ -443,10 +395,7 @@ where /// When dropped [`shutdown`] will be invoked. /// /// [`shutdown`]: Self::shutdown - pub async fn run(&self) - // where - // Svc::Stream: Send, - { + pub async fn run(&self) { if let Err(err) = self.run_until_error().await { error!("Server stopped due to error: {err}"); } @@ -455,10 +404,7 @@ where /// Reconfigure the server while running. /// /// - pub fn reconfigure( - &self, - config: Config, //,, - ) -> Result<(), Error> { + pub fn reconfigure(&self, config: Config) -> Result<(), Error> { self.command_tx .lock() .map_err(|_| Error::CommandCouldNotBeSent)? @@ -520,19 +466,14 @@ impl DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync, - Buf::Output: Octets + Send + Sync + 'static, + Buf::Output: Octets + Send + Sync + 'static + Unpin, Svc: Service + Send + Sync + 'static + Clone, - Svc::Target: Send + Composer + Default, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, + Svc::Target: Composer + Send, + Svc::Stream: Send, + Svc::Future: Send, { /// Receive incoming messages until shutdown or fatal error. - async fn run_until_error(&self) -> Result<(), String> -// where - // Svc::Stream: Send, - { + async fn run_until_error(&self) -> Result<(), String> { let mut command_rx = self.command_rx.clone(); loop { @@ -563,17 +504,6 @@ where let state = self.mk_state_for_request(); - // self.process_request( - // msg, received_at, addr, - // self.config.load().middleware_chain.clone(), - // &self.service, - // self.metrics.clone(), - // state, - // ) - // .map_err(|err| - // format!("Error while processing message: {err}") - // )?; - let svc = self.service.clone(); let cfg = self.config.clone(); let metrics = self.metrics.clone(); @@ -588,7 +518,7 @@ where let ctx = UdpTransportContext::new(cfg.load().max_response_size); let ctx = TransportSpecificContext::Udp(ctx); let request = Request::new(addr, received_at, msg, ctx); - let mut stream = svc.call(request); + let mut stream = svc.call(request).await; while let Some(Ok(call_result)) = stream.next().await { let (response, feedback) = call_result.into_inner(); @@ -599,6 +529,10 @@ where } => { // Nothing to do. } + + ServiceFeedback::BeginTransaction|ServiceFeedback::EndTransaction => { + // Nothing to do. + } } } @@ -642,7 +576,7 @@ where fn process_server_command( &self, res: Result<(), watch::error::RecvError>, - command_rx: &mut CommandReceiver, //, + command_rx: &mut CommandReceiver, ) -> Result<(), String> { // If the parent server no longer exists but was not cleanly shutdown // then the command channel will be closed and attempting to check for @@ -729,7 +663,6 @@ where /// [`CommonMessageFlow`] call chain and ultimately back to ourselves at /// [`process_call_reusult`]. fn mk_state_for_request(&self) -> RequestState { - //}, Buf::Output, Svc::Target> { RequestState::new( self.sock.clone(), self.command_tx.clone(), @@ -738,99 +671,17 @@ where } } -// //--- CommonMessageFlow - -// impl CommonMessageFlow -// for DgramServer -// where -// Sock: AsyncDgramSock + Send + Sync + 'static, -// Buf: BufSource + Send + Sync + 'static, -// Buf::Output: Octets + Send + Sync + 'static, -// Svc: Service + Send + Sync + 'static, -// Svc::Target: Send + Composer + Default, -// { -// type Meta = RequestState; - -// /// Add information to the request that relates to the type of server we -// /// are and our state where relevant. -// fn add_context_to_request( -// &self, -// request: Message, -// received_at: Instant, -// addr: SocketAddr, -// ) -> Request { -// let ctx = -// UdpTransportContext::new(self.config.load().max_response_size); -// let ctx = TransportSpecificContext::Udp(ctx); -// Request::new(addr, received_at, request, ctx) -// } - -// /// Process the result from the middleware -> service -> middleware call -// /// tree. -// fn process_call_result( -// request: &Request, -// call_result: CallResult, -// state: RequestState, -// metrics: Arc, -// ) { -// metrics.inc_num_pending_writes(); -// let client_addr = request.client_addr(); - -// tokio::spawn(async move { -// let (response, feedback) = call_result.into_inner(); - -// if let Some(feedback) = feedback { -// match feedback { -// ServiceFeedback::Reconfigure { -// idle_timeout: _, // N/A - only applies to connection-oriented transports -// } => { -// // Nothing to do. -// } -// } -// } - -// // Process the DNS response message, if any. -// if let Some(response) = response { -// // Convert the DNS response message into bytes. -// let target = response.finish(); -// let bytes = target.as_dgram_slice(); - -// // Logging -// if enabled!(Level::TRACE) { -// let pcap_text = to_pcap_text(bytes, bytes.len()); -// trace!(%client_addr, pcap_text, "Sending response"); -// } - -// // Actually write the DNS response message bytes to the UDP -// // socket. -// let _ = Self::send_to( -// &state.sock, -// bytes, -// &client_addr, -// state.write_timeout, -// ) -// .await; - -// metrics.dec_num_pending_writes(); -// metrics.inc_num_sent_responses(); -// } -// }); -// } -// } - //--- Drop impl Drop for DgramServer where Sock: AsyncDgramSock + Send + Sync + 'static, Buf: BufSource + Send + Sync + 'static, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + 'static + Clone, - Svc::Target: Send + Composer + Default, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, + Svc::Target: Composer + Send, + Svc::Stream: Send, + Svc::Future: Send, { fn drop(&mut self) { // Shutdown the DgramServer. Don't handle the failure case here as @@ -845,7 +696,6 @@ where /// Data needed by [`DgramServer::process_call_result`] which needs to be /// passed through the [`CommonMessageFlow`] call chain. pub struct RequestState { - //, RequestOctets, Target> { /// The network socket over which this request was received and over which /// the response should be sent. sock: Arc, @@ -853,19 +703,18 @@ pub struct RequestState { /// A sender for sending [`ServerCommand`]s. /// /// Used to signal the server to stop, reconfigure, etc. - command_tx: CommandSender, //, + command_tx: CommandSender, /// The maximum amount of time to wait for a response datagram to be /// accepted by the operating system for writing back to the client. write_timeout: Duration, } -impl RequestState { - //, RequestOctets, Target> { +impl RequestState { /// Creates a new instance of [`RequestState`]. fn new( sock: Arc, - command_tx: CommandSender, //, + command_tx: CommandSender, write_timeout: Duration, ) -> Self { Self { @@ -878,9 +727,7 @@ impl RequestState { //--- Clone -impl Clone - for RequestState -{ +impl Clone for RequestState { fn clone(&self) -> Self { Self { sock: self.sock.clone(), diff --git a/src/net/server/message.rs b/src/net/server/message.rs index 0fc06003a..e84e5da1e 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -137,7 +137,7 @@ impl TransportSpecificContext { /// message itself but also on the circumstances surrounding its creation and /// delivery. #[derive(Debug)] -pub struct Request> { +pub struct Request + Send + Sync + Unpin> { /// The network address of the connected client. client_addr: std::net::SocketAddr, @@ -152,7 +152,7 @@ pub struct Request> { transport_specific: TransportSpecificContext, } -impl> Request { +impl + Send + Sync + Unpin> Request { /// Creates a new request wrapper around a message along with its context. pub fn new( client_addr: std::net::SocketAddr, @@ -191,7 +191,7 @@ impl> Request { //--- Clone -impl> Clone for Request { +impl + Send + Sync + Unpin> Clone for Request { fn clone(&self) -> Self { Self { client_addr: self.client_addr, diff --git a/src/net/server/middleware/cookies.rs b/src/net/server/middleware/cookies.rs index 239ef75a9..4754fabfc 100644 --- a/src/net/server/middleware/cookies.rs +++ b/src/net/server/middleware/cookies.rs @@ -1,16 +1,12 @@ //! DNS Cookies related message processing. -use core::future::ready; +use core::future::{ready, Ready}; use core::marker::PhantomData; use core::ops::ControlFlow; -use core::pin::Pin; -use core::task::{Context, Poll}; use std::net::IpAddr; use std::vec::Vec; -use futures::stream::once; -use futures::Stream; -use futures_util::StreamExt; +use futures::stream::{once, Once}; use octseq::Octets; use rand::RngCore; use tracing::{debug, enabled, trace, warn, Level}; @@ -22,11 +18,12 @@ use crate::base::opt::Cookie; use crate::base::wire::{Composer, ParseError}; use crate::base::{Serial, StreamTarget}; use crate::net::server::message::Request; -use crate::net::server::middleware::util::MiddlewareStream; -use crate::net::server::service::{CallResult, Service, ServiceError}; +use crate::net::server::middleware::stream::MiddlewareStream; +use crate::net::server::service::{CallResult, Service, ServiceResult}; use crate::net::server::util::{add_edns_options, to_pcap_text}; use crate::net::server::util::{mk_builder_for_target, start_reply}; -use std::sync::Arc; + +use super::stream::PostprocessingStream; /// The five minute period referred to by /// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. @@ -48,34 +45,37 @@ const ONE_HOUR_AS_SECS: u32 = 60 * 60; /// [7873]: https://datatracker.ietf.org/doc/html/rfc7873 /// [9018]: https://datatracker.ietf.org/doc/html/rfc7873 /// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor -#[derive(Debug)] -pub struct CookiesMiddlewareSvc { - inner: S, +#[derive(Clone, Debug)] +pub struct CookiesMiddlewareSvc { + svc: Svc, /// A user supplied secret used in making the cookie value. - server_secret: Arc<[u8; 16]>, + server_secret: [u8; 16], /// Clients connecting from these IP addresses will be required to provide /// a cookie otherwise they will receive REFUSED with TC=1 prompting them /// to reconnect with TCP in order to "authenticate" themselves. ip_deny_list: Vec, + + _phantom: PhantomData, } -impl CookiesMiddlewareSvc { +impl CookiesMiddlewareSvc { /// Creates an instance of this processor. #[must_use] - pub fn new(inner: S, server_secret: [u8; 16]) -> Self { + pub fn new(svc: Svc, server_secret: [u8; 16]) -> Self { Self { - inner, - server_secret: Arc::new(server_secret), + svc, + server_secret, ip_deny_list: vec![], + _phantom: PhantomData, } } - pub fn with_random_secret(inner: S) -> Self { + pub fn with_random_secret(svc: Svc) -> Self { let mut server_secret = [0u8; 16]; rand::thread_rng().fill_bytes(&mut server_secret); - Self::new(inner, server_secret) + Self::new(svc, server_secret) } /// Define IP addresses required to supply DNS cookies if using UDP. @@ -89,7 +89,12 @@ impl CookiesMiddlewareSvc { } } -impl CookiesMiddlewareSvc { +impl CookiesMiddlewareSvc +where + RequestOctets: Octets + Send + Sync + Unpin, + Svc: Service, + Svc::Target: Composer + Default, +{ /// Get the DNS COOKIE, if any, for the given message. /// /// https://datatracker.ietf.org/doc/html/rfc7873#section-5.2: Responding @@ -104,7 +109,7 @@ impl CookiesMiddlewareSvc { /// - Some(Err(err)) if the request has a cookie that we could not /// parse. #[must_use] - fn cookie( + fn cookie( request: &Request, ) -> Option> { // Note: We don't use `opt::Opt::first()` because that will silently @@ -140,16 +145,12 @@ impl CookiesMiddlewareSvc { } /// Create a DNS response message for the given request, including cookie. - fn response_with_cookie( + fn response_with_cookie( &self, request: &Request, rcode: OptRcode, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - let mut additional = start_reply(request).additional(); + ) -> AdditionalBuilder> { + let mut additional = start_reply(request.message()).additional(); if let Some(Ok(client_cookie)) = Self::cookie(request) { let response_cookie = client_cookie.create_response( @@ -181,14 +182,10 @@ impl CookiesMiddlewareSvc { /// client cookie or is unable to write to an internal buffer while /// constructing the response. #[must_use] - fn bad_cookie_response( + fn bad_cookie_response( &self, request: &Request, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { + ) -> AdditionalBuilder> { // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 // "If the server responds [ed: by sending a BADCOOKIE error // response], it SHALL generate its own COOKIE option containing @@ -201,14 +198,10 @@ impl CookiesMiddlewareSvc { /// Create a DNS response to a client cookie prefetch request. #[must_use] - fn prefetch_cookie_response( + fn prefetch_cookie_response( &self, request: &Request, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { + ) -> AdditionalBuilder> { // https://datatracker.ietf.org/doc/html/rfc7873#section-5.4 // Querying for a Server Cookie: // "For servers with DNS Cookies enabled, the @@ -226,8 +219,8 @@ impl CookiesMiddlewareSvc { /// Check the cookie contained in the request to make sure that it is /// complete, and if so return the cookie to the caller. #[must_use] - fn ensure_cookie_is_complete( - request: &Request, + fn ensure_cookie_is_complete( + request: &Request, server_secret: &[u8; 16], ) -> Option { if let Some(Ok(cookie)) = Self::cookie(request) { @@ -246,19 +239,11 @@ impl CookiesMiddlewareSvc { None } } -} - -//--- MiddlewareProcessor -impl CookiesMiddlewareSvc { - fn preprocess( + fn preprocess( &self, request: &Request, - ) -> ControlFlow>> - where - RequestOctets: Octets, - Target: Composer + Default, - { + ) -> ControlFlow>> { match Self::cookie(request) { None => { trace!("Request does not include DNS cookies"); @@ -438,13 +423,12 @@ impl CookiesMiddlewareSvc { ControlFlow::Continue(()) } - fn postprocess( + fn postprocess( request: &Request, - response: &mut AdditionalBuilder>, - server_secret: &[u8; 16], + response: &mut AdditionalBuilder>, + server_secret: [u8; 16], ) where RequestOctets: Octets, - Target: Composer + Default, { // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 // No OPT RR or No COOKIE Option: @@ -472,7 +456,7 @@ impl CookiesMiddlewareSvc { // pre-processing, we don't need to check it again here. if let Some(filled_cookie) = - Self::ensure_cookie_is_complete(request, server_secret) + Self::ensure_cookie_is_complete(request, &server_secret) { // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 // "The server SHALL process the request and include a COOKIE @@ -500,125 +484,68 @@ impl CookiesMiddlewareSvc { trace!(pcap_text, "post-processing complete"); } } + + fn map_stream_item( + request: Request, + mut stream_item: ServiceResult, + server_secret: [u8; 16], + ) -> ServiceResult { + if let Ok(cr) = &mut stream_item { + if let Some(response) = cr.response_mut() { + Self::postprocess(&request, response, server_secret); + } + } + stream_item + } } //--- Service -impl Service - for CookiesMiddlewareSvc +impl Service + for CookiesMiddlewareSvc where - RequestOctets: Octets + 'static, - S: Service, - S::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin - + 'static, - Target: Composer + Default + 'static + Unpin, + RequestOctets: Octets + Send + Sync + 'static + Unpin, + Svc: Service, + Svc::Future: core::future::Future + Unpin, + ::Output: Unpin, + Svc::Target: Composer + Default, { - type Target = Target; + type Target = Svc::Target; type Stream = MiddlewareStream< - S::Stream, - PostprocessingStream, - Target, + Svc::Stream, + PostprocessingStream< + RequestOctets, + Svc::Future, + Svc::Stream, + [u8; 16], + >, + Once::Item>>, + ::Item, >; + type Future = core::future::Ready; - fn call(&self, request: Request) -> Self::Stream { + fn call(&self, request: Request) -> Self::Future { match self.preprocess(&request) { ControlFlow::Continue(()) => { - let st = self.inner.call(request.clone()); + let svc_call_fut = self.svc.call(request.clone()); let map = PostprocessingStream::new( - st, + svc_call_fut, request, - self.server_secret.clone(), + self.server_secret, + Self::map_stream_item, ); - MiddlewareStream::Postprocess(map) + ready(MiddlewareStream::Map(map)) } ControlFlow::Break(mut response) => { Self::postprocess( &request, &mut response, - &self.server_secret, + self.server_secret, ); - - MiddlewareStream::HandledOne(once(ready(Ok( + ready(MiddlewareStream::Result(once(ready(Ok( CallResult::new(response), - )))) + ))))) } } } } - -pub struct PostprocessingStream< - RequestOctets, - Target, - InnerServiceResponseStream, -> where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - >, -{ - request: Request, - server_secret: Arc<[u8; 16]>, - stream: InnerServiceResponseStream, - _phantom: PhantomData, -} - -impl<'a, RequestOctets, Target, InnerServiceResponseStream> - PostprocessingStream -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - >, -{ - pub(crate) fn new( - stream: InnerServiceResponseStream, - request: Request, - server_secret: Arc<[u8; 16]>, - ) -> Self { - Self { - stream, - request, - server_secret, - _phantom: PhantomData, - } - } -} - -impl Stream - for PostprocessingStream< - RequestOctets, - Target, - InnerServiceResponseStream, - > -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Target: Composer + Default + Unpin, -{ - type Item = Result, ServiceError>; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let res = futures::ready!(self.stream.poll_next_unpin(cx)); - let request = self.request.clone(); - let server_secret = self.server_secret.clone(); - Poll::Ready(res.map(|mut res| { - if let Ok(cr) = &mut res { - if let Some(response) = cr.get_response_mut() { - CookiesMiddlewareSvc::::postprocess(&request, response, &server_secret); - } - } - res - })) - } - - fn size_hint(&self) -> (usize, Option) { - self.stream.size_hint() - } -} diff --git a/src/net/server/middleware/edns.rs b/src/net/server/middleware/edns.rs index a5cca9f30..2c980a2b0 100644 --- a/src/net/server/middleware/edns.rs +++ b/src/net/server/middleware/edns.rs @@ -1,18 +1,12 @@ //! RFC 6891 and related EDNS message processing. -use core::future::ready; +use core::future::{ready, Ready}; use core::marker::PhantomData; use core::ops::ControlFlow; -use core::task::{Context, Poll}; -use std::pin::Pin; - -use futures::stream::once; -use futures::Stream; -use futures_util::StreamExt; +use futures::stream::{once, Once}; use octseq::Octets; use tracing::{debug, enabled, error, trace, warn, Level}; -use super::mandatory::MINIMUM_RESPONSE_BYTE_LEN; use crate::base::iana::{OptRcode, OptionCode}; use crate::base::message_builder::AdditionalBuilder; use crate::base::opt::keepalive::IdleTimeout; @@ -20,10 +14,14 @@ use crate::base::opt::{Opt, OptRecord, TcpKeepalive}; use crate::base::wire::Composer; use crate::base::StreamTarget; use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::middleware::util::MiddlewareStream; -use crate::net::server::service::{CallResult, Service, ServiceError}; -use crate::net::server::util::start_reply; -use crate::net::server::util::{add_edns_options, remove_edns_opt_record}; +use crate::net::server::middleware::stream::MiddlewareStream; +use crate::net::server::service::{CallResult, Service, ServiceResult}; +use crate::net::server::util::{ + add_edns_options, mk_error_response, remove_edns_opt_record, +}; + +use super::mandatory::MINIMUM_RESPONSE_BYTE_LEN; +use super::stream::PostprocessingStream; /// EDNS version 0. /// @@ -47,58 +45,34 @@ const EDNS_VERSION_ZERO: u8 = 0; /// [7828]: https://datatracker.ietf.org/doc/html/rfc7828 /// [9210]: https://datatracker.ietf.org/doc/html/rfc9210 /// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor -#[derive(Debug, Default)] -pub struct EdnsMiddlewareSvc { - inner: S, +#[derive(Clone, Debug, Default)] +pub struct EdnsMiddlewareSvc { + svc: Svc, + + _phantom: PhantomData, } -impl EdnsMiddlewareSvc { +impl EdnsMiddlewareSvc { /// Creates an instance of this processor. #[must_use] - pub fn new(inner: S) -> Self { - Self { inner } - } -} - -impl EdnsMiddlewareSvc { - /// Create a DNS error response to the given request with the given RCODE. - fn error_response( - request: &Request, - rcode: OptRcode, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - let mut additional = start_reply(request).additional(); - - // Note: if rcode is non-extended this will also correctly handle - // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |_, opt| { - opt.set_rcode(rcode); - Ok(()) - }) { - warn!( - "Failed to set (extended) error '{rcode}' in response: {err}" - ); + pub fn new(svc: Svc) -> Self { + Self { + svc, + _phantom: PhantomData, } - - Self::postprocess(request, &mut additional); - additional } } -//--- MiddlewareProcessor - -impl EdnsMiddlewareSvc { - fn preprocess( +impl EdnsMiddlewareSvc +where + RequestOctets: Octets + Send + Sync + Unpin, + Svc: Service, + Svc::Target: Composer + Default, +{ + fn preprocess( &self, request: &Request, - ) -> ControlFlow>> - where - RequestOctets: Octets, - Target: Composer + Default, - { + ) -> ControlFlow>> { // https://www.rfc-editor.org/rfc/rfc6891.html#section-6.1.1 // 6.1.1: Basic Elements // ... @@ -111,8 +85,8 @@ impl EdnsMiddlewareSvc { if iter.next().is_some() { // More than one OPT RR received. debug!("RFC 6891 6.1.1 violation: request contains more than one OPT RR."); - return ControlFlow::Break(Self::error_response( - request, + return ControlFlow::Break(mk_error_response( + request.message(), OptRcode::FORMERR, )); } @@ -127,8 +101,8 @@ impl EdnsMiddlewareSvc { // RCODE=BADVERS." if opt_rec.version() > EDNS_VERSION_ZERO { debug!("RFC 6891 6.1.3 violation: request EDNS version {} > 0", opt_rec.version()); - return ControlFlow::Break(Self::error_response( - request, + return ControlFlow::Break(mk_error_response( + request.message(), OptRcode::BADVERS, )); } @@ -148,8 +122,8 @@ impl EdnsMiddlewareSvc { if opt_rec.opt().tcp_keepalive().is_some() { debug!("RFC 7828 3.2.1 violation: edns-tcp-keepalive option received via UDP"); return ControlFlow::Break( - Self::error_response( - request, + mk_error_response( + request.message(), OptRcode::FORMERR, ), ); @@ -226,8 +200,8 @@ impl EdnsMiddlewareSvc { if keep_alive.timeout().is_some() { debug!("RFC 7828 3.2.1 violation: edns-tcp-keepalive option received via TCP contains timeout"); return ControlFlow::Break( - Self::error_response( - request, + mk_error_response( + request.message(), OptRcode::FORMERR, ), ); @@ -242,13 +216,10 @@ impl EdnsMiddlewareSvc { ControlFlow::Continue(()) } - fn postprocess( + fn postprocess( request: &Request, - response: &mut AdditionalBuilder>, - ) where - RequestOctets: Octets, - Target: Composer + Default, - { + response: &mut AdditionalBuilder>, + ) { // https://www.rfc-editor.org/rfc/rfc6891.html#section-6.1.1 // 6.1.1: Basic Elements // ... @@ -273,7 +244,8 @@ impl EdnsMiddlewareSvc { error!( "Error while stripping OPT record from response: {err}" ); - *response = Self::error_response(request, OptRcode::SERVFAIL); + *response = + mk_error_response(request.message(), OptRcode::SERVFAIL); return; } } @@ -334,124 +306,67 @@ impl EdnsMiddlewareSvc { // record in the request) should we set the Requestor's Payload Size // field to some value? } + + fn map_stream_item( + request: Request, + mut stream_item: ServiceResult, + _metadata: (), + ) -> ServiceResult { + if let Ok(cr) = &mut stream_item { + if let Some(response) = cr.response_mut() { + Self::postprocess(&request, response); + } + } + stream_item + } } //--- Service -impl Service for EdnsMiddlewareSvc +impl Service + for EdnsMiddlewareSvc where - RequestOctets: Octets + 'static, - S: Service, - S::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin - + 'static, - Target: Composer + Default + 'static + Unpin, + RequestOctets: Octets + Send + Sync + 'static + Unpin, + Svc: Service, + Svc::Target: Composer + Default, + Svc::Future: Unpin, { - type Target = Target; + type Target = Svc::Target; type Stream = MiddlewareStream< - S::Stream, - PostprocessingStream, - Target, + Svc::Stream, + PostprocessingStream, + Once::Item>>, + ::Item, >; + type Future = core::future::Ready; - fn call(&self, request: Request) -> Self::Stream { + fn call(&self, request: Request) -> Self::Future { match self.preprocess(&request) { ControlFlow::Continue(()) => { - let st = self.inner.call(request.clone()); - let map = PostprocessingStream::new(st, request); - MiddlewareStream::Postprocess(map) + let svc_call_fut = self.svc.call(request.clone()); + let map = PostprocessingStream::new( + svc_call_fut, + request, + (), + Self::map_stream_item, + ); + ready(MiddlewareStream::Map(map)) } ControlFlow::Break(mut response) => { Self::postprocess(&request, &mut response); - MiddlewareStream::HandledOne(once(ready(Ok( + ready(MiddlewareStream::Result(once(ready(Ok( CallResult::new(response), - )))) + ))))) } } } } -pub struct PostprocessingStream< - RequestOctets, - Target, - InnerServiceResponseStream, -> where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - >, -{ - request: Request, - _phantom: PhantomData, - stream: InnerServiceResponseStream, -} - -impl - PostprocessingStream -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - >, -{ - pub(crate) fn new( - stream: InnerServiceResponseStream, - request: Request, - ) -> Self { - Self { - stream, - request, - _phantom: PhantomData, - } - } -} - -impl Stream - for PostprocessingStream< - RequestOctets, - Target, - InnerServiceResponseStream, - > -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Target: Composer + Default + Unpin, -{ - type Item = Result, ServiceError>; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let res = futures::ready!(self.stream.poll_next_unpin(cx)); - let request = self.request.clone(); - Poll::Ready(res.map(|mut res| { - if let Ok(cr) = &mut res { - if let Some(response) = cr.get_response_mut() { - EdnsMiddlewareSvc::::postprocess(&request, response); - } - } - res - })) - } - - fn size_hint(&self) -> (usize, Option) { - self.stream.size_hint() - } -} - #[cfg(test)] mod tests { - use core::pin::Pin; - - use std::boxed::Box; use std::vec::Vec; use bytes::Bytes; - use futures::stream::Once; use futures::stream::StreamExt; use tokio::time::Instant; @@ -462,7 +377,7 @@ mod tests { use crate::base::iana::Rcode; use crate::net::server::middleware::mandatory::MINIMUM_RESPONSE_BYTE_LEN; - use crate::net::server::service::{CallResult, Service, ServiceError}; + use crate::net::server::service::{CallResult, Service, ServiceResult}; use crate::net::server::util::{mk_builder_for_target, service_fn}; use super::EdnsMiddlewareSvc; @@ -554,36 +469,23 @@ mod tests { fn my_service( req: Request>, _meta: (), - ) -> Once< - Pin< - Box< - dyn std::future::Future< - Output = Result< - CallResult>, - ServiceError, - >, - > + Send, - >, - >, - > { + ) -> ServiceResult> { // For each request create a single response: - let msg = req.message().clone(); - futures::stream::once(Box::pin(async move { - let builder = mk_builder_for_target(); - let answer = builder.start_answer(&msg, Rcode::NXDOMAIN)?; - Ok(CallResult::new(answer.additional())) - })) + let builder = mk_builder_for_target(); + let answer = + builder.start_answer(req.message(), Rcode::NXDOMAIN)?; + Ok(CallResult::new(answer.additional())) } // Either call the service directly. let my_svc = service_fn(my_service, ()); - let mut stream = my_svc.call(request.clone()); + let mut stream = my_svc.call(request.clone()).await; let _call_result: CallResult> = stream.next().await.unwrap().unwrap(); // Or pass the query through the middleware processor let processor_svc = EdnsMiddlewareSvc::new(my_svc); - let mut stream = processor_svc.call(request.clone()); + let mut stream = processor_svc.call(request.clone()).await; let call_result: CallResult> = stream.next().await.unwrap().unwrap(); let (_response, _feedback) = call_result.into_inner(); diff --git a/src/net/server/middleware/mandatory.rs b/src/net/server/middleware/mandatory.rs index 70d071bc6..912531e21 100644 --- a/src/net/server/middleware/mandatory.rs +++ b/src/net/server/middleware/mandatory.rs @@ -1,25 +1,23 @@ //! Core DNS RFC standards based message processing for MUST requirements. -use core::future::ready; +use core::future::{ready, Ready}; use core::marker::PhantomData; use core::ops::ControlFlow; -use core::pin::Pin; -use core::task::{Context, Poll}; use std::fmt::Display; -use futures::stream::once; -use futures::{Stream, StreamExt}; +use futures::stream::{once, Once}; use octseq::Octets; use tracing::{debug, error, trace, warn}; -use super::util::MiddlewareStream; -use crate::base::iana::{Opcode, Rcode}; +use crate::base::iana::{Opcode, OptRcode}; use crate::base::message_builder::{AdditionalBuilder, PushError}; use crate::base::wire::{Composer, ParseError}; -use crate::base::StreamTarget; +use crate::base::{Message, StreamTarget}; use crate::net::server::message::{Request, TransportSpecificContext}; -use crate::net::server::service::{CallResult, Service, ServiceError}; -use crate::net::server::util::{mk_builder_for_target, start_reply}; +use crate::net::server::service::{CallResult, Service, ServiceResult}; +use crate::net::server::util::{mk_builder_for_target, mk_error_response}; + +use super::stream::{MiddlewareStream, PostprocessingStream}; /// The minimum legal UDP response size in bytes. /// @@ -42,28 +40,26 @@ pub const MINIMUM_RESPONSE_BYTE_LEN: u16 = 512; /// crate::net::server::middleware::processor::MiddlewareProcessor /// [1035]: https://datatracker.ietf.org/doc/html/rfc1035 /// [2181]: https://datatracker.ietf.org/doc/html/rfc2181 -#[derive(Debug)] -pub struct MandatoryMiddlewareSvc { +#[derive(Clone, Debug)] +pub struct MandatoryMiddlewareSvc { /// In strict mode the processor does more checks on requests and /// responses. strict: bool, - inner: S, + svc: Svc, - _phantom: PhantomData<(RequestOctets, Target)>, + _phantom: PhantomData, } -impl - MandatoryMiddlewareSvc -{ +impl MandatoryMiddlewareSvc { /// Creates a new processor instance. /// /// The processor will operate in strict mode. #[must_use] - pub fn new(inner: S) -> Self { + pub fn new(svc: Svc) -> Self { Self { strict: true, - inner, + svc, _phantom: PhantomData, } } @@ -72,34 +68,20 @@ impl /// /// The processor will operate in relaxed mode. #[must_use] - pub fn relaxed(inner: S) -> Self { + pub fn relaxed(svc: Svc) -> Self { Self { strict: false, - inner, + svc, _phantom: PhantomData, } } - - /// Create a DNS error response to the given request with the given RCODE. - fn error_response( - request: &Request, - rcode: Rcode, - strict: bool, - ) -> AdditionalBuilder> - where - RequestOctets: Octets, - Target: Composer + Default, - { - let mut response = start_reply(request); - response.header_mut().set_rcode(rcode); - let mut additional = response.additional(); - Self::postprocess(request, &mut additional, strict); - additional - } } -impl - MandatoryMiddlewareSvc +impl MandatoryMiddlewareSvc +where + RequestOctets: Octets + Send + Sync + Unpin, + Svc: Service, + Svc::Target: Composer + Default, { /// Truncate the given response message if it is too large. /// @@ -113,12 +95,8 @@ impl /// specified byte length. fn truncate( request: &Request, - response: &mut AdditionalBuilder>, - ) -> Result<(), TruncateError> - where - RequestOctets: Octets, - Target: Composer + Default, - { + response: &mut AdditionalBuilder>, + ) -> Result<(), TruncateError> { if let TransportSpecificContext::Udp(ctx) = request.transport_ctx() { // https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 // "Messages carried by UDP are restricted to 512 bytes (not @@ -213,27 +191,20 @@ impl fn preprocess( &self, - request: &Request, - ) -> ControlFlow>> - where - RequestOctets: Octets, - Target: Composer + Default, - { + msg: &Message, + ) -> ControlFlow>> { // https://www.rfc-editor.org/rfc/rfc3425.html // 3 - Effect on RFC 1035 // .. // "Therefore IQUERY is now obsolete, and name servers SHOULD return // a "Not Implemented" error when an IQUERY request is received." - if self.strict - && request.message().header().opcode() == Opcode::IQUERY - { + if self.strict && msg.header().opcode() == Opcode::IQUERY { debug!( "RFC 3425 3 violation: request opcode IQUERY is obsolete." ); - return ControlFlow::Break(Self::error_response( - request, - Rcode::NOTIMP, - self.strict, + return ControlFlow::Break(mk_error_response( + msg, + OptRcode::NOTIMP, )); } @@ -242,16 +213,13 @@ impl fn postprocess( request: &Request, - response: &mut AdditionalBuilder>, + response: &mut AdditionalBuilder>, strict: bool, - ) where - RequestOctets: Octets, - Target: Composer + Default, - { + ) { if let Err(err) = Self::truncate(request, response) { error!("Error while truncating response: {err}"); *response = - Self::error_response(request, Rcode::SERVFAIL, strict); + mk_error_response(request.message(), OptRcode::SERVFAIL); return; } @@ -292,126 +260,62 @@ impl warn!("RFC 1035 violation: response question count != request question count"); } } + + fn map_stream_item( + request: Request, + mut stream_item: ServiceResult, + strict: bool, + ) -> ServiceResult { + if let Ok(cr) = &mut stream_item { + if let Some(response) = cr.response_mut() { + Self::postprocess(&request, response, strict); + } + } + stream_item + } } //--- Service -impl Service - for MandatoryMiddlewareSvc +impl Service + for MandatoryMiddlewareSvc where - RequestOctets: Octets + 'static, - S: Service, - S::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin - + 'static, - Target: Composer + Default + 'static + Unpin, + RequestOctets: Octets + Send + Sync + 'static + Unpin, + Svc: Service, + Svc::Future: Unpin, + Svc::Target: Composer + Default, { - type Target = Target; + type Target = Svc::Target; type Stream = MiddlewareStream< - S::Stream, - PostprocessingStream, - Target, + Svc::Stream, + PostprocessingStream, + Once::Item>>, + ::Item, >; + type Future = Ready; - fn call(&self, request: Request) -> Self::Stream { - match self.preprocess(&request) { + fn call(&self, request: Request) -> Self::Future { + match self.preprocess(request.message()) { ControlFlow::Continue(()) => { - let st = self.inner.call(request.clone()); - let map = PostprocessingStream::new(st, request, self.strict); - MiddlewareStream::Postprocess(map) + let svc_call_fut = self.svc.call(request.clone()); + let map = PostprocessingStream::new( + svc_call_fut, + request, + self.strict, + Self::map_stream_item, + ); + ready(MiddlewareStream::Map(map)) } ControlFlow::Break(mut response) => { Self::postprocess(&request, &mut response, self.strict); - MiddlewareStream::HandledOne(once(ready(Ok( + ready(MiddlewareStream::Result(once(ready(Ok( CallResult::new(response), - )))) + ))))) } } } } -pub struct PostprocessingStream< - RequestOctets, - Target, - InnerServiceResponseStream, -> where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - >, -{ - request: Request, - strict: bool, - _phantom: PhantomData, - stream: InnerServiceResponseStream, -} - -impl - PostprocessingStream -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - >, -{ - pub(crate) fn new( - stream: InnerServiceResponseStream, - request: Request, - strict: bool, - ) -> Self { - Self { - stream, - request, - strict, - _phantom: PhantomData, - } - } -} - -impl Stream - for PostprocessingStream< - RequestOctets, - Target, - InnerServiceResponseStream, - > -where - RequestOctets: Octets, - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Target: Composer + Default + Unpin, -{ - type Item = Result, ServiceError>; - - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let res = futures::ready!(self.stream.poll_next_unpin(cx)); - let request = self.request.clone(); - let strict = self.strict; - Poll::Ready(res.map(|mut res| { - if let Ok(cr) = &mut res { - if let Some(response) = cr.get_response_mut() { - MandatoryMiddlewareSvc::< - RequestOctets, - InnerServiceResponseStream, - Target, - >::postprocess( - &request, response, strict - ); - } - } - res - })) - } - - fn size_hint(&self) -> (usize, Option) { - self.stream.size_hint() - } -} - //------------ TruncateError ------------------------------------------------- /// An error occured during oversize response truncation. @@ -451,13 +355,9 @@ impl From for TruncateError { #[cfg(test)] mod tests { - use core::pin::Pin; - - use std::boxed::Box; use std::vec::Vec; use bytes::Bytes; - use futures::stream::Once; use futures::StreamExt; use octseq::OctetsBuilder; use tokio::time::Instant; @@ -467,7 +367,7 @@ mod tests { use crate::net::server::message::{ Request, TransportSpecificContext, UdpTransportContext, }; - use crate::net::server::service::{CallResult, Service, ServiceError}; + use crate::net::server::service::{CallResult, Service, ServiceResult}; use crate::net::server::util::{mk_builder_for_target, service_fn}; use super::{MandatoryMiddlewareSvc, MINIMUM_RESPONSE_BYTE_LEN}; @@ -537,36 +437,23 @@ mod tests { fn my_service( req: Request>, _meta: (), - ) -> Once< - Pin< - Box< - dyn std::future::Future< - Output = Result< - CallResult>, - ServiceError, - >, - > + Send, - >, - >, - > { + ) -> ServiceResult> { // For each request create a single response: - let msg = req.message().clone(); - futures::stream::once(Box::pin(async move { - let builder = mk_builder_for_target(); - let answer = builder.start_answer(&msg, Rcode::NXDOMAIN)?; - Ok(CallResult::new(answer.additional())) - })) + let builder = mk_builder_for_target(); + let answer = + builder.start_answer(req.message(), Rcode::NXDOMAIN)?; + Ok(CallResult::new(answer.additional())) } // Either call the service directly. let my_svc = service_fn(my_service, ()); - let mut stream = my_svc.call(request.clone()); + let mut stream = my_svc.call(request.clone()).await; let _call_result: CallResult> = stream.next().await.unwrap().unwrap(); // Or pass the query through the middleware processor let processor_svc = MandatoryMiddlewareSvc::new(my_svc); - let mut stream = processor_svc.call(request); + let mut stream = processor_svc.call(request).await; let call_result: CallResult> = stream.next().await.unwrap().unwrap(); let (response, _feedback) = call_result.into_inner(); diff --git a/src/net/server/middleware/mod.rs b/src/net/server/middleware/mod.rs index 816f10339..1a1434ed8 100644 --- a/src/net/server/middleware/mod.rs +++ b/src/net/server/middleware/mod.rs @@ -2,4 +2,4 @@ pub mod cookies; pub mod edns; pub mod mandatory; -pub mod util; +pub mod stream; diff --git a/src/net/server/middleware/stream.rs b/src/net/server/middleware/stream.rs new file mode 100644 index 000000000..b20847333 --- /dev/null +++ b/src/net/server/middleware/stream.rs @@ -0,0 +1,166 @@ +use core::ops::DerefMut; +use core::task::{Context, Poll}; + +use std::pin::Pin; + +use futures::prelude::future::FutureExt; +use futures::stream::{Stream, StreamExt}; +use octseq::Octets; + +use crate::net::server::message::Request; +use tracing::trace; + +//------------ MiddlewareStream ---------------------------------------------- + +pub enum MiddlewareStream +where + IdentityStream: Stream, + MapStream: Stream, + ResultStream: Stream, +{ + /// The inner service response will be passed through this service without + /// modification. + Identity(IdentityStream), + + /// Either a single response has been created without invoking the innter + /// service, or the inner service response will be post-processed by this + /// service. + Map(MapStream), + + /// A response has been created without invoking the inner service. + Result(ResultStream), +} + +//--- impl Stream + +impl Stream + for MiddlewareStream +where + IdentityStream: Stream + Unpin, + MapStream: Stream + Unpin, + ResultStream: Stream + Unpin, + Self: Unpin, +{ + type Item = StreamItem; + + fn poll_next( + mut self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + ) -> Poll> { + match self.deref_mut() { + MiddlewareStream::Identity(s) => s.poll_next_unpin(cx), + MiddlewareStream::Map(s) => s.poll_next_unpin(cx), + MiddlewareStream::Result(s) => s.poll_next_unpin(cx), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MiddlewareStream::Identity(s) => s.size_hint(), + MiddlewareStream::Map(s) => s.size_hint(), + MiddlewareStream::Result(s) => s.size_hint(), + } + } +} + +//------------ PostprocessingStreamState ------------------------------------- + +enum PostprocessingStreamState +where + Stream: futures::stream::Stream, + Future: core::future::Future, +{ + Pending(Future), + Streaming(Stream), +} + +//------------ PostprocessingStreamCallback ---------------------------------- + +type PostprocessingStreamCallback = + fn(Request, StreamItem, Metadata) -> StreamItem; + +//------------ PostprocessingStream ------------------------------------------ + +pub struct PostprocessingStream +where + RequestOctets: Octets + Send + Sync + Unpin, + Future: core::future::Future, + Stream: futures::stream::Stream, +{ + request: Request, + state: PostprocessingStreamState, + cb: PostprocessingStreamCallback, + metadata: Metadata, +} + +impl + PostprocessingStream +where + RequestOctets: Octets + Send + Sync + Unpin, + Future: core::future::Future, + Stream: futures::stream::Stream, +{ + pub fn new( + svc_call_fut: Future, + request: Request, + metadata: Metadata, + cb: PostprocessingStreamCallback< + RequestOctets, + Stream::Item, + Metadata, + >, + ) -> Self { + Self { + state: PostprocessingStreamState::Pending(svc_call_fut), + request, + cb, + metadata, + } + } +} + +//--- impl Stream + +impl futures::stream::Stream + for PostprocessingStream +where + RequestOctets: Octets + Send + Sync + Unpin, + Future: core::future::Future + Unpin, + Stream: futures::stream::Stream + Unpin, + Self: Unpin, + Metadata: Clone, +{ + type Item = Stream::Item; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll> { + match &mut self.state { + PostprocessingStreamState::Pending(svc_call_fut) => { + let stream = futures::ready!(svc_call_fut.poll_unpin(cx)); + trace!("Stream has become available"); + self.state = PostprocessingStreamState::Streaming(stream); + self.poll_next(cx) + } + PostprocessingStreamState::Streaming(stream) => { + let stream_item = futures::ready!(stream.poll_next_unpin(cx)); + trace!("Stream item retrieved, mapping to downstream type"); + let request = self.request.clone(); + let metadata = self.metadata.clone(); + let map = stream_item + .map(|item| (self.cb)(request, item, metadata)); + Poll::Ready(map) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match &self.state { + PostprocessingStreamState::Pending(_fut) => (0, None), + PostprocessingStreamState::Streaming(stream) => { + stream.size_hint() + } + } + } +} diff --git a/src/net/server/middleware/util.rs b/src/net/server/middleware/util.rs deleted file mode 100644 index c408450ec..000000000 --- a/src/net/server/middleware/util.rs +++ /dev/null @@ -1,80 +0,0 @@ -use core::ops::DerefMut; - -use std::future::Ready; - -use futures::stream::{FuturesOrdered, Once}; -use futures::Stream; -use futures_util::StreamExt; - -use crate::base::wire::Composer; -use crate::net::server::service::{CallResult, ServiceError}; - -pub enum MiddlewareStream< - InnerServiceResponseStream, - PostprocessingStream, - Target, -> where - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - PostprocessingStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Self: Unpin, - Target: Unpin, -{ - /// The inner service response will be passed through this service without - /// modification. - Passthru(InnerServiceResponseStream), - - /// The inner service response will be post-processed by this service. - Postprocess(PostprocessingStream), - - /// A single response has been created without invoking the inner service. - HandledOne(Once, ServiceError>>>), - - /// Multiple responses have been created without invoking the inner - /// service. - HandledMany( - FuturesOrdered, ServiceError>>>, - ), -} - -impl Stream - for MiddlewareStream< - InnerServiceResponseStream, - PostprocessingStream, - Target, - > -where - InnerServiceResponseStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - PostprocessingStream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Unpin, - Target: Composer + Default + Unpin, -{ - type Item = Result, ServiceError>; - - fn poll_next( - mut self: core::pin::Pin<&mut Self>, - cx: &mut core::task::Context<'_>, - ) -> core::task::Poll> { - match self.deref_mut() { - MiddlewareStream::Passthru(s) => s.poll_next_unpin(cx), - MiddlewareStream::Postprocess(s) => s.poll_next_unpin(cx), - MiddlewareStream::HandledOne(s) => s.poll_next_unpin(cx), - MiddlewareStream::HandledMany(s) => s.poll_next_unpin(cx), - } - } - - fn size_hint(&self) -> (usize, Option) { - match self { - MiddlewareStream::Passthru(s) => s.size_hint(), - MiddlewareStream::Postprocess(s) => s.size_hint(), - MiddlewareStream::HandledOne(s) => s.size_hint(), - MiddlewareStream::HandledMany(s) => s.size_hint(), - } - } -} diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 13c8c34a7..cae1d2d8d 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -2,8 +2,8 @@ not(feature = "unstable-server-transport"), doc = " The `unstable-server-transport` feature is necessary to enable this module." )] -#![warn(missing_docs)] -#![warn(clippy::missing_docs_in_private_items)] +// #![warn(missing_docs)] +// #![warn(clippy::missing_docs_in_private_items)] //! Receiving requests and sending responses. //! //! This module provides skeleton asynchronous server implementations based on diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 60ae8e226..854320a22 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -7,25 +7,24 @@ use core::fmt::Display; use core::ops::Deref; -use std::boxed::Box; -use std::future::Future; -use std::pin::Pin; use std::sync::Arc; use std::time::Duration; use std::vec::Vec; -use futures_util::stream::FuturesOrdered; -use futures_util::{FutureExt, StreamExt}; - -use super::message::Request; use crate::base::iana::Rcode; use crate::base::message_builder::{AdditionalBuilder, PushError}; use crate::base::wire::ParseError; use crate::base::StreamTarget; -use octseq::{OctetsBuilder, ShortBuf}; + +use super::message::Request; +use core::future::ready; +use futures::stream::once; //------------ Service ------------------------------------------------------- +/// The type of item that `Service` implementations stream as output. +pub type ServiceResult = Result, ServiceError>; + /// [`Service`]s are responsible for determining how to respond to valid DNS /// requests. /// @@ -199,48 +198,53 @@ use octseq::{OctetsBuilder, ShortBuf}; /// [net::server module documentation]: crate::net::server /// [`call`]: Self::call() /// [`service_fn`]: crate::net::server::util::service_fn() -pub trait Service = Vec> { - /// The type of buffer in which response messages are stored. +pub trait Service + Send + Sync + Unpin = Vec> +{ + /// The underlying byte storage type used to hold generated responses. type Target; - type Stream; + /// The type of stream that the service produces. + type Stream: futures::stream::Stream> + + Unpin; - /// The type of future returned by [`Service::call()`] via - /// [`Transaction::single()`]. - // type Item: ; + /// The type of future that will yield the service result stream. + type Future: core::future::Future; /// Generate a response to a fully pre-processed request. - #[allow(clippy::type_complexity)] - fn call(&self, request: Request) -> Self::Stream; + fn call(&self, request: Request) -> Self::Future; } +//--- impl Service for Arc + /// Helper trait impl to treat an [`Arc`] as a [`Service`]. -impl, T: Service> - Service for Arc +impl Service for Arc +where + RequestOctets: Unpin + Send + Sync + AsRef<[u8]>, + T: ?Sized + Service, { type Target = T::Target; type Stream = T::Stream; + type Future = T::Future; - fn call(&self, request: Request) -> Self::Stream { + fn call(&self, request: Request) -> Self::Future { Arc::deref(self).call(request) } } +//--- impl Service for functions with matching signature + /// Helper trait impl to treat a function as a [`Service`]. -impl Service for F +impl Service for F where - RequestOctets: AsRef<[u8]>, - F: Fn(Request) -> Stream, - Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + 'static, + RequestOctets: AsRef<[u8]> + Send + Sync + Unpin, + F: Fn(Request) -> ServiceResult, { type Target = Target; - type Stream = Stream; + type Stream = futures::stream::Once>>; + type Future = core::future::Ready; - fn call(&self, request: Request) -> Self::Stream { - (*self)(request) + fn call(&self, request: Request) -> Self::Future { + ready(once(ready((*self)(request)))) } } @@ -315,6 +319,12 @@ pub enum ServiceFeedback { /// server to use. idle_timeout: Option, }, + + /// Ensure that messages from this stream are all enqueued, don't drop + /// messages if the outgoing queue is full. + BeginTransaction, + + EndTransaction, } //------------ CallResult ---------------------------------------------------- @@ -334,14 +344,14 @@ pub struct CallResult { /// Optional response to send back to the client. response: Option>>, - /// Optioanl feedback from the [`Service`] to the server. + /// Optional feedback from the [`Service`] to the server. feedback: Option, } impl CallResult -where - Target: OctetsBuilder + AsRef<[u8]> + AsMut<[u8]>, - Target::AppendError: Into, +// where +// Target: OctetsBuilder + AsRef<[u8]> + AsMut<[u8]>, +// Target::AppendError: Into, { /// Construct a [`CallResult`] from a DNS response message. #[must_use] @@ -376,7 +386,15 @@ where /// Get a mutable reference to the contained DNS response message, if any. #[must_use] - pub fn get_response_mut( + pub fn response( + &self, + ) -> Option<&AdditionalBuilder>> { + self.response.as_ref() + } + + /// Get a mutable reference to the contained DNS response message, if any. + #[must_use] + pub fn response_mut( &mut self, ) -> Option<&mut AdditionalBuilder>> { self.response.as_mut() @@ -395,173 +413,3 @@ where (response, feedback) } } - -//------------ Transaction --------------------------------------------------- - -/// Zero or more DNS response futures relating to a single DNS request. -/// -/// A transaction is either empty, a single DNS response future, or a stream -/// of DNS response futures. -/// -/// # Usage -/// -/// Either: -/// - Construct a transaction for a [`single`] response future, OR -/// - Construct a transaction [`stream`] and [`push`] response futures into -/// it. -/// -/// Then iterate over the response futures one at a time using [`next`]. -/// -/// [`single`]: Self::single() -/// [`stream`]: Self::stream() -/// [`push`]: TransactionStream::push() -/// [`next`]: Self::next() -pub struct Transaction(TransactionInner) -where - Future: std::future::Future< - Output = Result, ServiceError>, - >; - -impl Transaction -where - Future: std::future::Future< - Output = Result, ServiceError>, - >, -{ - /// Construct a transaction for a single immediate response. - pub(crate) fn immediate( - item: Result, ServiceError>, - ) -> Self { - Self(TransactionInner::Immediate(Some(item))) - } - - /// Construct an empty transaction. - pub fn empty() -> Self { - Self(TransactionInner::Single(None)) - } - - /// Construct a transaction for a single response future. - pub fn single(fut: Future) -> Self { - Self(TransactionInner::Single(Some(fut))) - } - - /// Construct a transaction for a future stream of response futures. - /// - /// The given future should build the stream of response futures that will - /// eventually be resolved by [`Self::next`]. - /// - /// This takes a future instead of a [`TransactionStream`] because the - /// caller may not yet know how many futures they need to push into the - /// stream and we don't want them to block us while they work that out. - pub fn stream( - fut: Pin< - Box> + Send>, - >, - ) -> Self { - Self(TransactionInner::PendingStream(fut)) - } - - /// Take the next response from the transaction, if any. - /// - /// This function provides a single way to take futures from the - /// transaction without needing to handle which type of transaction it is. - /// - /// Returns None if there are no (more) responses to take, Some(future) - /// otherwise. - pub async fn next( - &mut self, - ) -> Option, ServiceError>> { - match &mut self.0 { - TransactionInner::Immediate(item) => item.take(), - - TransactionInner::Single(opt_fut) => match opt_fut.take() { - Some(fut) => Some(fut.await), - None => None, - }, - - TransactionInner::PendingStream(stream_fut) => { - let mut stream = stream_fut.await; - let next = stream.next().await; - self.0 = TransactionInner::Stream(stream); - next - } - - TransactionInner::Stream(stream) => stream.next().await, - } - } -} - -//------------ TransactionInner ---------------------------------------------- - -/// Private inner details of the [`Transaction`] type. -/// -/// This type exists to (a) hide the `Immediate` variant from the consumer of -/// this library as it is for internal use only and not something a -/// [`Service`] impl should return, and (b) to control the interface offered -/// to consumers of this type and avoid them having to work with the enum -/// variants directly. -enum TransactionInner -where - Future: std::future::Future< - Output = Result, ServiceError>, - >, -{ - /// The transaction will result in a single immediate response. - /// - /// This variant is for internal use only when aborting Middleware - /// processing early. - Immediate(Option, ServiceError>>), - - /// The transaction will result in at most a single response future. - Single(Option), - - /// The transaction will result in stream of multiple response futures. - PendingStream( - Pin> + Send>>, - ), - - /// The transaction is a stream of multiple response futures. - Stream(Stream), -} - -//------------ TransacationStream -------------------------------------------- - -/// A [`TransactionStream`] of [`Service`] results. -type Stream = - TransactionStream, ServiceError>>; - -/// A stream of zero or more DNS response futures relating to a single DNS request. -pub struct TransactionStream { - /// An ordered sequence of futures that will resolve to responses to be - /// sent back to the client. - stream: FuturesOrdered + Send>>>, -} - -impl TransactionStream { - /// Add a response future to a transaction stream. - pub fn push + Send + 'static>( - &mut self, - fut: T, - ) { - self.stream.push_back(fut.boxed()); - } - - /// Fetch the next message from the stream, if any. - async fn next(&mut self) -> Option { - self.stream.next().await - } -} - -impl Default for TransactionStream { - fn default() -> Self { - Self { - stream: Default::default(), - } - } -} - -impl std::fmt::Debug for TransactionStream { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - f.debug_struct("TransactionStream").finish() - } -} diff --git a/src/net/server/stream.rs b/src/net/server/stream.rs index 76febc667..6ff3263c3 100644 --- a/src/net/server/stream.rs +++ b/src/net/server/stream.rs @@ -38,7 +38,6 @@ use crate::utils::config::DefMinMax; use super::buf::VecBufSource; use super::connection::{self, Connection}; -use super::service::{CallResult, ServiceError}; use super::ServerCommand; use crate::base::wire::Composer; use tokio::io::{AsyncRead, AsyncWrite}; @@ -173,7 +172,7 @@ impl Clone for Config { Self { accept_connections_at_max: self.accept_connections_at_max, max_concurrent_connections: self.max_concurrent_connections, - connection_config: self.connection_config.clone(), + connection_config: self.connection_config, } } } @@ -278,13 +277,9 @@ pub struct StreamServer where Listener: AsyncAccept + Send + Sync, Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Composer + Default + 'static, + Svc::Target: Composer + Default, // + 'static, { /// The configuration of the server. config: Arc>, @@ -327,12 +322,8 @@ impl StreamServer where Listener: AsyncAccept + Send + Sync, Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, Svc::Target: Composer + Default, { /// Creates a new [`StreamServer`] instance. @@ -411,12 +402,8 @@ impl StreamServer where Listener: AsyncAccept + Send + Sync, Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Debug + Send + Sync, + Buf::Output: Octets + Debug + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, Svc::Target: Composer + Default, { /// Get a reference to the source for this server. @@ -438,13 +425,9 @@ impl StreamServer where Listener: AsyncAccept + Send + Sync, Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Composer + Default + 'static, + Svc::Target: Composer + Default, { /// Start the server. /// @@ -463,6 +446,7 @@ where Svc: 'static, Svc::Target: Send + Sync, Svc::Stream: Send, + Svc::Future: Send, { if let Err(err) = self.run_until_error().await { error!("Server stopped due to error: {err}"); @@ -536,12 +520,8 @@ impl StreamServer where Listener: AsyncAccept + Send + Sync, Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, Svc::Target: Composer + Default, { /// Accept stream connections until shutdown or fatal error. @@ -553,11 +533,9 @@ where Listener::Future: Send + 'static, Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Send + Sync + 'static, + Svc::Target: Send + Sync, + Svc::Stream: Send, + Svc::Future: Send, { let mut command_rx = self.command_rx.clone(); @@ -677,17 +655,15 @@ where Listener::Future: Send + 'static, Listener::StreamType: AsyncRead + AsyncWrite + Send + Sync + 'static, Svc: 'static, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Send + Sync + 'static, + Svc::Target: Composer + Send + Sync, + Svc::Stream: Send, + Svc::Future: Send, { // Work around the compiler wanting to move self to the async block by // preparing only those pieces of information from self for the new // connection handler that it actually needs. let config = ArcSwap::load(&self.config); - let conn_config = config.connection_config.clone(); + let conn_config = config.connection_config; let conn_command_rx = self.command_rx.clone(); let conn_service = self.service.clone(); let conn_buf = self.buf.clone(); @@ -744,13 +720,9 @@ impl Drop for StreamServer where Listener: AsyncAccept + Send + Sync, Buf: BufSource + Send + Sync + Clone, - Buf::Output: Octets + Send + Sync, + Buf::Output: Octets + Send + Sync + Unpin, Svc: Service + Send + Sync + Clone, - Svc::Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + Unpin, - Svc::Target: Composer + Default + 'static, + Svc::Target: Composer + Default, { fn drop(&mut self) { // Shutdown the StreamServer. Don't handle the failure case here as diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index ae8511108..04e4cc818 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -13,17 +13,21 @@ use std::vec::Vec; use tokio::io::{AsyncRead, AsyncWrite}; use tokio::time::sleep; use tokio::time::Instant; +use tracing::trace; use crate::base::Dname; use crate::base::MessageBuilder; use crate::base::Rtype; use crate::base::StaticCompressor; use crate::base::StreamTarget; +use crate::net::server::middleware::mandatory::MandatoryMiddlewareSvc; +use crate::net::server::stream::StreamServer; use super::buf::BufSource; use super::message::Request; use super::service::{CallResult, Service, ServiceError, ServiceFeedback}; use super::sock::AsyncAccept; +use core::future::{ready, Ready}; // use super::stream::StreamServer; /// Mock I/O which supplies a sequence of mock messages to the server at a @@ -310,9 +314,11 @@ impl MyService { impl Service> for MyService { type Target = Vec; type Stream = MySingle; + type Future = Ready; - fn call(&self, _msg: Request>) -> MySingle { - MySingle::new() + fn call(&self, request: Request>) -> Self::Future { + trace!("Processing request id {}", request.message().header().id()); + ready(MySingle::new()) } } @@ -347,65 +353,99 @@ fn mk_query() -> StreamTarget> { // signal that time has passed when in fact it actually hasn't, allowing a // time dependent test to run much faster without actual periods of // waiting to allow time to elapse. -// mk_query().as_dgram_slice().to_vec(), -// mk_query().as_dgram_slice().to_vec(), -// mk_query().as_dgram_slice().to_vec(), -// mk_query().as_dgram_slice().to_vec(), -// mk_query().as_dgram_slice().to_vec(), +#[tokio::test(flavor = "current_thread", start_paused = true)] +async fn service_test() { + // Initialize tracing based logging. Override with env var RUST_LOG, e.g. + // RUST_LOG=trace. + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_thread_ids(true) + .without_time() + .try_init() + .ok(); + + let (srv_handle, server_status_printer_handle) = { + let fast_client = MockClientConfig { + new_message_every: Duration::from_millis(100), + messages: VecDeque::from([ + mk_query().as_dgram_slice().to_vec(), + mk_query().as_dgram_slice().to_vec(), + mk_query().as_dgram_slice().to_vec(), + mk_query().as_dgram_slice().to_vec(), + mk_query().as_dgram_slice().to_vec(), ]), -// mk_query().as_dgram_slice().to_vec(), -// mk_query().as_dgram_slice().to_vec(), + client_port: 1, + }; + let slow_client = MockClientConfig { + new_message_every: Duration::from_millis(3000), + messages: VecDeque::from([ + mk_query().as_dgram_slice().to_vec(), + mk_query().as_dgram_slice().to_vec(), ]), + client_port: 2, + }; + let num_messages = + fast_client.messages.len() + slow_client.messages.len(); + let streams_to_read = VecDeque::from([fast_client, slow_client]); + let new_client_every = Duration::from_millis(2000); + let listener = MockListener::new(streams_to_read, new_client_every); + let ready_flag = listener.get_ready_flag(); + + let buf = MockBufSource; + let my_service = Arc::new(MandatoryMiddlewareSvc::new(MyService::new())); + // let my_service = Arc::new(MyService::new()); + let srv = + Arc::new(StreamServer::new(listener, buf, my_service.clone())); + + let metrics = srv.metrics(); + let server_status_printer_handle = tokio::spawn(async move { + loop { + sleep(Duration::from_millis(250)).await; + eprintln!( + "Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", + metrics.num_connections(), + metrics.num_inflight_requests(), + metrics.num_pending_writes(), + metrics.num_received_requests(), + metrics.num_sent_responses(), + ); + } + }); + + let spawned_srv = srv.clone(); + let srv_handle = tokio::spawn(async move { spawned_srv.run().await }); + + eprintln!("Clients sleeping"); + sleep(Duration::from_secs(1)).await; + + eprintln!("Clients connecting"); + ready_flag.store(true, Ordering::Relaxed); + + // Simulate a wait long enough that all simulated clients had time + // to connect, communicate and disconnect. + sleep(Duration::from_secs(20)).await; -// let metrics = srv.metrics(); -// let server_status_printer_handle = tokio::spawn(async move { -// loop { -// sleep(Duration::from_millis(250)).await; -// eprintln!( -// "Server status: #conn={:?}, #in-flight={}, #pending-writes={}, #msgs-recvd={}, #msgs-sent={}", -// metrics.num_connections(), -// metrics.num_inflight_requests(), -// metrics.num_pending_writes(), -// metrics.num_received_requests(), -// metrics.num_sent_responses(), -// ); -// } -// }); - -// let spawned_srv = srv.clone(); -// let srv_handle = tokio::spawn(async move { spawned_srv.run().await }); - -// eprintln!("Clients sleeping"); -// sleep(Duration::from_secs(1)).await; - -// eprintln!("Clients connecting"); -// ready_flag.store(true, Ordering::Relaxed); - -// // Simulate a wait long enough that all simulated clients had time -// // to connect, communicate and disconnect. -// sleep(Duration::from_secs(20)).await; - -// // Verify that all simulated clients connected. -// assert_eq!(0, srv.source().streams_remaining()); - -// // Verify that no requests or responses are in progress still in -// // the server. -// assert_eq!(srv.metrics().num_connections(), 0); -// assert_eq!(srv.metrics().num_inflight_requests(), 0); -// assert_eq!(srv.metrics().num_pending_writes(), 0); -// assert_eq!(srv.metrics().num_received_requests(), num_messages); -// assert_eq!(srv.metrics().num_sent_responses(), num_messages); - -// eprintln!("Shutting down"); -// srv.shutdown().unwrap(); -// eprintln!("Shutdown command sent"); - -// (srv_handle, server_status_printer_handle) -// }; - -// eprintln!("Waiting for service to shutdown"); -// let _ = srv_handle.await; - -// // Terminate the task that periodically prints the server status -// server_status_printer_handle.abort(); -// } + // Verify that all simulated clients connected. + assert_eq!(0, srv.source().streams_remaining()); + + // Verify that no requests or responses are in progress still in + // the server. + assert_eq!(srv.metrics().num_connections(), 0); + assert_eq!(srv.metrics().num_inflight_requests(), 0); + assert_eq!(srv.metrics().num_pending_writes(), 0); + assert_eq!(srv.metrics().num_received_requests(), num_messages); + assert_eq!(srv.metrics().num_sent_responses(), num_messages); + + eprintln!("Shutting down"); + srv.shutdown().unwrap(); + eprintln!("Shutdown command sent"); + + (srv_handle, server_status_printer_handle) + }; + + eprintln!("Waiting for service to shutdown"); + let _ = srv_handle.await; + + // Terminate the task that periodically prints the server status + server_status_printer_handle.abort(); +} diff --git a/src/net/server/util.rs b/src/net/server/util.rs index c84cf76cb..034d70861 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -1,9 +1,14 @@ //! Small utilities for building and working with servers. +use core::future::Ready; + use std::string::{String, ToString}; +use futures::stream::Once; use octseq::{Octets, OctetsBuilder}; +use smallvec::SmallVec; use tracing::warn; +use crate::base::iana::{OptRcode, OptionCode, Rcode}; use crate::base::message_builder::{ AdditionalBuilder, OptBuilder, PushError, QuestionBuilder, }; @@ -14,9 +19,7 @@ use crate::base::{MessageBuilder, ParsedDname, Rtype, StreamTarget}; use crate::rdata::AllRecordData; use super::message::Request; -use super::service::{CallResult, Service, ServiceError}; -use crate::base::iana::{OptionCode, Rcode}; -use smallvec::SmallVec; +use super::service::{Service, ServiceResult}; //----------- mk_builder_for_target() ---------------------------------------- @@ -33,7 +36,7 @@ where ) } -//------------ service_fn() -------------------------------------------------- +//------------ streaming_service_fn() ---------------------------------------- /// Helper to simplify making a [`Service`] impl. /// @@ -41,9 +44,17 @@ where /// those of its associated types, but this makes implementing it for simple /// cases quite verbose. /// -/// `service_fn()` enables you to write a slightly simpler function definition -/// that implements the [`Service`] trait than implementing [`Service`] -/// directly. +/// `streaming_service_fn()` enables you to write a slightly simpler function +/// definition that implements the [`Service`] trait than implementing +/// [`Service`] directly. +/// +/// The provided function must produce a future that results in a stream of +/// futures. The envisaged use case for producing a stream of results in the +/// context of DNS is zone transfers. If you need to implement zone transfer +/// or other streaming support yourself then you should implement [`Service`] +/// directly or via `streaming_service_fn`. +/// +/// Most users should probably use `service_fn` instead. /// /// # Example /// @@ -99,18 +110,19 @@ where /// [`Vec`]: std::vec::Vec /// [`CallResult`]: crate::net::server::service::CallResult /// [`Result::Ok`]: std::result::Result::Ok -pub fn service_fn( +pub fn service_fn( request_handler: T, metadata: Metadata, -) -> impl Service + Clone +) -> impl Service< + RequestOctets, + Target = Target, + Stream = Once>>, + Future = Ready>>>, +> + Clone where - RequestOctets: AsRef<[u8]>, - Stream: futures::stream::Stream< - Item = Result, ServiceError>, - > + Send - + 'static, + RequestOctets: AsRef<[u8]> + Send + Sync + Unpin, Metadata: Clone, - T: Fn(Request, Metadata) -> Stream + Clone, + T: Fn(Request, Metadata) -> ServiceResult + Clone, { move |request| request_handler(request, metadata.clone()) } @@ -156,7 +168,7 @@ pub(crate) fn to_pcap_text>( /// On internal error this function will attempt to set RCODE ServFail in the /// returned message. pub fn start_reply( - request: &Request, + msg: &Message, ) -> QuestionBuilder> where RequestOctets: Octets, @@ -167,7 +179,7 @@ where // RFC (1035?) compliance - copy question from request to response. let mut abort = false; let mut builder = builder.question(); - for rr in request.message().question() { + for rr in msg.question() { match rr { Ok(rr) => { if let Err(err) = builder.push(rr) { @@ -193,6 +205,30 @@ where builder } +//------------ mk_error_response --------------------------------------------- + +pub fn mk_error_response( + msg: &Message, + rcode: OptRcode, +) -> AdditionalBuilder> +where + RequestOctets: Octets, + Target: Composer + Default, +{ + let mut additional = start_reply(msg).additional(); + + // Note: if rcode is non-extended this will also correctly handle + // setting the rcode in the main message header. + if let Err(err) = add_edns_options(&mut additional, |_, opt| { + opt.set_rcode(rcode); + Ok(()) + }) { + warn!("Failed to set (extended) error '{rcode}' in response: {err}"); + } + + additional +} + //----------- add_edns_option ------------------------------------------------ /// Adds one or more EDNS OPT options to a response. diff --git a/src/zonetree/zone.rs b/src/zonetree/zone.rs index d160b14b3..bdac14649 100644 --- a/src/zonetree/zone.rs +++ b/src/zonetree/zone.rs @@ -15,7 +15,7 @@ use super::{ReadableZone, StoredDname, ZoneStore}; //------------ Zone ---------------------------------------------------------- /// A single DNS zone. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct Zone { store: Arc, } diff --git a/tests/net-server.rs b/tests/net-server.rs index c046b45cc..99b9d9fd6 100644 --- a/tests/net-server.rs +++ b/tests/net-server.rs @@ -1,11 +1,12 @@ #![cfg(feature = "net")] mod net; +use std::boxed::Box; use std::collections::VecDeque; use std::fs::File; -use std::future::Future; use std::net::IpAddr; use std::path::PathBuf; +use std::result::Result; use std::sync::Arc; use std::time::Duration; @@ -18,18 +19,18 @@ use domain::base::iana::Rcode; use domain::base::wire::Composer; use domain::base::{Dname, ToDname}; use domain::net::client::{dgram, stream}; +use domain::net::server; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; -use domain::net::server::middleware::builder::MiddlewareBuilder; #[cfg(feature = "siphasher")] -use domain::net::server::middleware::processors::cookies::CookiesMiddlewareProcessor; -use domain::net::server::middleware::processors::edns::EdnsMiddlewareProcessor; -use domain::net::server::service::{ - CallResult, Service, ServiceError, Transaction, -}; +use domain::net::server::middleware::cookies::CookiesMiddlewareSvc; +use domain::net::server::middleware::edns::EdnsMiddlewareSvc; +use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; +use domain::net::server::service::{CallResult, Service, ServiceResult}; use domain::net::server::stream::StreamServer; -use domain::net::server::util::{mk_builder_for_target, service_fn}; +use domain::net::server::util::mk_builder_for_target; +use domain::net::server::util::service_fn; use domain::zonefile::inplace::{Entry, ScannedRecord, Zonefile}; use net::stelline::channel::ClientServerChannel; @@ -59,32 +60,85 @@ async fn server_tests(#[files("test-data/server/*.rpl")] rpl_file: PathBuf) { // and which responses will be expected, and how the server that // answers them should be configured. + // Initialize tracing based logging. Override with env var RUST_LOG, e.g. + // RUST_LOG=trace. DEBUG level will show the .rpl file name, Stelline step + // numbers and types as they are being executed. + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_thread_ids(true) + .without_time() + .try_init() + .ok(); + let file = File::open(&rpl_file).unwrap(); let stelline = parse_file(&file, rpl_file.to_str().unwrap()); let server_config = parse_server_config(&stelline.config); // Create a service to answer queries received by the DNS servers. let zonefile = server_config.zonefile.clone(); - let service: Arc<_> = service_fn(test_service, zonefile).into(); - - // Create dgram and stream servers for answering requests - let (dgram_srv, dgram_conn, stream_srv, stream_conn) = - mk_servers(service, &server_config); - - // Create a client factory for sending requests - let client_factory = mk_client_factory(dgram_conn, stream_conn); - // Run the Stelline test! - let step_value = Arc::new(CurrStepValue::new()); - do_client(&stelline, &step_value, client_factory).await; + let with_cookies = server_config.cookies.enabled + && server_config.cookies.secret.is_some(); + + async fn finish_svc<'a, RequestOctets, Svc>( + svc: Svc, + server_config: ServerConfig<'a>, + stelline: &parse_stelline::Stelline, + ) where + RequestOctets: Octets + Send + Sync + Unpin, + Svc: Service + Send + Sync + 'static, + // TODO: Why are the following bounds needed to persuade the compiler + // that the `svc` value created _within the function_ (not the one + // passed in as an argument) is actually an impl of the Service trait? + MandatoryMiddlewareSvc, Svc>: Service + Send + Sync, + , Svc> as Service>::Target: + Composer + Default + Send + Sync, + , Svc> as Service>::Stream: + Send + Sync, + , Svc> as Service>::Future: + Send + Sync, + { + let svc = MandatoryMiddlewareSvc::, _>::new(svc); + let svc = Arc::new(svc); + + // Create dgram and stream servers for answering requests + let (dgram_srv, dgram_conn, stream_srv, stream_conn) = + mk_servers(svc, &server_config); + + // Create a client factory for sending requests + let client_factory = mk_client_factory(dgram_conn, stream_conn); + + // Run the Stelline test! + let step_value = Arc::new(CurrStepValue::new()); + do_client(stelline, &step_value, client_factory).await; + + // Await shutdown + if !dgram_srv.await_shutdown(Duration::from_secs(5)).await { + warn!("Datagram server did not shutdown on time."); + } - // Await shutdown - if !dgram_srv.await_shutdown(Duration::from_secs(5)).await { - warn!("Datagram server did not shutdown on time."); + if !stream_srv.await_shutdown(Duration::from_secs(5)).await { + warn!("Stream server did not shutdown on time."); + } } - if !stream_srv.await_shutdown(Duration::from_secs(5)).await { - warn!("Stream server did not shutdown on time."); + let svc = service_fn(test_service, zonefile); + if with_cookies { + #[cfg(not(feature = "siphasher"))] + panic!("The test uses cookies but the required 'siphasher' feature is not enabled."); + + #[cfg(feature = "siphasher")] + let secret = server_config.cookies.secret.unwrap(); + let secret = hex::decode(secret).unwrap(); + let secret = <[u8; 16]>::try_from(secret).unwrap(); + let svc = CookiesMiddlewareSvc::new(svc, secret) + .with_denied_ips(server_config.cookies.ip_deny_list.clone()); + finish_svc(svc, server_config, &stelline).await; + } else if server_config.edns_tcp_keepalive { + let svc = EdnsMiddlewareSvc::new(svc); + finish_svc(svc, server_config, &stelline).await; + } else { + finish_svc(svc, server_config, &stelline).await; } } @@ -104,6 +158,7 @@ where Svc: Service + Send + Sync + 'static, Svc::Future: Send, Svc::Target: Composer + Default + Send + Sync, + Svc::Stream: Send, { // Prepare middleware to be used by the DNS servers to pre-process // received requests and post-process created responses. @@ -182,49 +237,15 @@ fn mk_client_factory( ]) } -fn mk_server_configs( +fn mk_server_configs( config: &ServerConfig, -) -> ( - domain::net::server::dgram::Config, - domain::net::server::stream::Config, -) -where - RequestOctets: Octets, - Target: Composer + Default, -{ - let mut middleware = MiddlewareBuilder::minimal(); - - if config.cookies.enabled { - #[cfg(feature = "siphasher")] - if let Some(secret) = config.cookies.secret { - let secret = hex::decode(secret).unwrap(); - let secret = <[u8; 16]>::try_from(secret).unwrap(); - let processor = CookiesMiddlewareProcessor::new(secret); - let processor = processor - .with_denied_ips(config.cookies.ip_deny_list.clone()); - middleware.push(processor.into()); - } - - #[cfg(not(feature = "siphasher"))] - panic!("The test uses cookies but the required 'siphasher' feature is not enabled."); - } - - if config.edns_tcp_keepalive { - let processor = EdnsMiddlewareProcessor::new(); - middleware.push(processor.into()); - } - - let middleware = middleware.build(); - - let mut dgram_config = domain::net::server::dgram::Config::default(); - dgram_config.set_middleware_chain(middleware.clone()); +) -> (server::dgram::Config, server::stream::Config) { + let dgram_config = server::dgram::Config::default(); - let mut stream_config = domain::net::server::stream::Config::default(); + let mut stream_config = server::stream::Config::default(); if let Some(idle_timeout) = config.idle_timeout { - let mut connection_config = - domain::net::server::ConnectionConfig::default(); + let mut connection_config = server::ConnectionConfig::default(); connection_config.set_idle_timeout(idle_timeout); - connection_config.set_middleware_chain(middleware); stream_config.set_connection_config(connection_config); } @@ -248,13 +269,7 @@ where fn test_service( request: Request>, zonefile: Zonefile, -) -> Result< - Transaction< - Vec, - impl Future>, ServiceError>> + Send, - >, - ServiceError, -> { +) -> ServiceResult> { fn as_record_and_dname( r: ScannedRecord, ) -> Option<(ScannedRecord, Dname>)> { @@ -275,43 +290,41 @@ fn test_service( } trace!("Service received request"); - Ok(Transaction::single(async move { - trace!("Service is constructing a single response"); - // If given a single question: - let answer = request - .message() - .sole_question() - .ok() - .and_then(|q| { - // Walk the zone to find the queried name - zonefile - .clone() - .filter_map(as_records) - .filter_map(as_record_and_dname) - .find(|(_record, dname)| dname == q.qname()) - }) - .map_or_else( - || { - // The Qname was not found in the zone: - mk_builder_for_target() - .start_answer(request.message(), Rcode::NXDOMAIN) - .unwrap() - }, - |(record, _)| { - // Respond with the found record: - let mut answer = mk_builder_for_target() - .start_answer(request.message(), Rcode::NOERROR) - .unwrap(); - // As we serve all answers from our own zones we are the - // authority for the domain in question. - answer.header_mut().set_aa(true); - answer.push(record).unwrap(); - answer - }, - ); - - Ok(CallResult::new(answer.additional())) - })) + trace!("Service is constructing a single response"); + // If given a single question: + let answer = request + .message() + .sole_question() + .ok() + .and_then(|q| { + // Walk the zone to find the queried name + zonefile + .clone() + .filter_map(as_records) + .filter_map(as_record_and_dname) + .find(|(_record, dname)| dname == q.qname()) + }) + .map_or_else( + || { + // The Qname was not found in the zone: + mk_builder_for_target() + .start_answer(request.message(), Rcode::NXDOMAIN) + .unwrap() + }, + |(record, _)| { + // Respond with the found record: + let mut answer = mk_builder_for_target() + .start_answer(request.message(), Rcode::NOERROR) + .unwrap(); + // As we serve all answers from our own zones we are the + // authority for the domain in question. + answer.header_mut().set_aa(true); + answer.push(record).unwrap(); + answer + }, + ); + + Ok(CallResult::new(answer.additional())) } //----------- Stelline config block parsing ----------------------------------- From 9fecccecdc3d559498448d7694a34d332a870cee Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 16:41:43 +0200 Subject: [PATCH 13/28] cargo fmt. --- src/net/server/service.rs | 4 +++- src/net/server/tests.rs | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 854320a22..46712372f 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -240,7 +240,9 @@ where F: Fn(Request) -> ServiceResult, { type Target = Target; - type Stream = futures::stream::Once>>; + type Stream = futures::stream::Once< + core::future::Ready>, + >; type Future = core::future::Ready; fn call(&self, request: Request) -> Self::Future { diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index 04e4cc818..9b181f418 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -293,7 +293,8 @@ impl futures::stream::Stream for MySingle { idle_timeout: Some(Duration::from_millis(5000)), }; - let call_result = CallResult::new(response).with_feedback(command); + let call_result = + CallResult::new(response).with_feedback(command); self.done = true; Poll::Ready(Some(Ok(call_result))) @@ -392,7 +393,8 @@ async fn service_test() { let ready_flag = listener.get_ready_flag(); let buf = MockBufSource; - let my_service = Arc::new(MandatoryMiddlewareSvc::new(MyService::new())); + let my_service = + Arc::new(MandatoryMiddlewareSvc::new(MyService::new())); // let my_service = Arc::new(MyService::new()); let srv = Arc::new(StreamServer::new(listener, buf, my_service.clone())); From 3a09c12a6476d9e3bafa899947c2ac4d1934ee95 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 20:42:15 +0200 Subject: [PATCH 14/28] Merge branch 'main' into service-layering --- .github/workflows/ci.yml | 2 +- Cargo.toml | 47 +- Changelog.md | 55 ++- examples/client-transports.rs | 22 +- examples/client.rs | 9 +- examples/common/serve-utils.rs | 6 +- examples/download-rust-lang.rs | 6 +- examples/lookup.rs | 16 +- examples/query-zone.rs | 10 +- examples/resolv-sync.rs | 4 +- examples/serve-zone.rs | 2 +- examples/server-transports.rs | 45 +- src/base/cmp.rs | 4 +- src/base/iana/class.rs | 10 + src/base/iana/macros.rs | 25 +- src/base/iana/rcode.rs | 112 +++-- src/base/message.rs | 66 +-- src/base/message_builder.rs | 72 +-- src/base/mod.rs | 2 +- src/base/name/{dname.rs => absolute.rs} | 613 +++++++++++++----------- src/base/name/builder.rs | 150 ++++-- src/base/name/chain.rs | 194 ++++---- src/base/name/mod.rs | 56 +-- src/base/name/parsed.rs | 233 +++++---- src/base/name/relative.rs | 496 ++++++++++--------- src/base/name/traits.rs | 151 +++--- src/base/name/uncertain.rs | 197 ++++---- src/base/net/nostd.rs | 6 + src/base/opt/chain.rs | 28 +- src/base/opt/macros.rs | 4 +- src/base/opt/mod.rs | 10 +- src/base/question.rs | 42 +- src/base/record.rs | 78 +-- src/base/scan.rs | 12 +- src/base/wire.rs | 8 +- src/net/client/cache.rs | 27 +- src/net/client/mod.rs | 4 +- src/net/client/protocol.rs | 27 +- src/net/client/request.rs | 8 +- src/net/server/message.rs | 18 +- src/net/server/middleware/cookies.rs | 154 +++--- src/net/server/middleware/edns.rs | 28 +- src/net/server/middleware/mandatory.rs | 25 +- src/net/server/service.rs | 8 +- src/net/server/tests.rs | 31 +- src/net/server/util.rs | 210 ++++++-- src/rdata/dname.rs | 6 +- src/rdata/dnssec.rs | 74 +-- src/rdata/macros.rs | 88 ++-- src/rdata/mod.rs | 8 +- src/rdata/rfc1035/minfo.rs | 46 +- src/rdata/rfc1035/mod.rs | 4 +- src/rdata/rfc1035/mx.rs | 36 +- src/rdata/rfc1035/{dname.rs => name.rs} | 22 +- src/rdata/rfc1035/soa.rs | 46 +- src/rdata/srv.rs | 34 +- src/rdata/svcb/rdata.rs | 34 +- src/rdata/tsig.rs | 32 +- src/rdata/zonemd.rs | 4 +- src/resolv/lookup/addr.rs | 89 +--- src/resolv/lookup/host.rs | 14 +- src/resolv/lookup/srv.rs | 36 +- src/resolv/resolver.rs | 6 +- src/resolv/stub/conf.rs | 45 +- src/resolv/stub/mod.rs | 63 +-- src/sign/key.rs | 6 +- src/sign/records.rs | 44 +- src/sign/ring.rs | 4 +- src/tsig/interop.rs | 22 +- src/tsig/mod.rs | 28 +- src/validate.rs | 48 +- src/zonefile/inplace.rs | 67 ++- src/zonefile/mod.rs | 3 - src/zonetree/answer.rs | 9 +- src/{zonefile => zonetree}/error.rs | 95 +++- src/zonetree/in_memory/builder.rs | 18 +- src/zonetree/in_memory/nodes.rs | 10 +- src/zonetree/in_memory/read.rs | 6 +- src/zonetree/mod.rs | 19 +- src/{zonefile => zonetree}/parsed.rs | 64 ++- src/zonetree/traits.rs | 25 +- src/zonetree/tree.rs | 62 +-- src/zonetree/types.rs | 29 +- src/zonetree/walk.rs | 8 +- src/zonetree/zone.rs | 7 +- tests/net-client-cache.rs | 4 +- tests/net-server.rs | 6 +- tests/net/stelline/matches.rs | 9 +- tests/net/stelline/parse_query.rs | 36 +- 89 files changed, 2493 insertions(+), 2156 deletions(-) rename src/base/name/{dname.rs => absolute.rs} (74%) rename src/rdata/rfc1035/{dname.rs => name.rs} (91%) rename src/{zonefile => zonetree}/error.rs (63%) rename src/{zonefile => zonetree}/parsed.rs (83%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 777d5ae4c..c2863e432 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ jobs: strategy: matrix: os: [ubuntu-latest, windows-latest, macOS-latest] - rust: [1.67.0, stable, beta, nightly] + rust: [1.70.0, stable, beta, nightly] env: RUSTFLAGS: "-D warnings" steps: diff --git a/Cargo.toml b/Cargo.toml index 99bb3ff51..7073e2877 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "domain" -version = "0.10.0-dev" -rust-version = "1.67.0" +version = "0.10.1-dev" +rust-version = "1.70.0" edition = "2021" authors = ["NLnet Labs "] description = "A DNS library for Rust." @@ -30,7 +30,7 @@ futures-util = { version = "0.3", optional = true } heapless = { version = "0.8", optional = true } hex = { version = "0.4", optional = true } libc = { version = "0.2.153", default-features = false, optional = true } # 0.2.79 is the first version that has IP_PMTUDISC_OMIT -parking_lot = { version = "0.11.2", optional = true } +parking_lot = { version = "0.12.2", optional = true } moka = { version = "0.12.3", optional = true, features = ["future"] } proc-macro2 = { version = "1.0.69", optional = true } # Force proc-macro2 to at least 1.0.69 for minimal-version build ring = { version = "0.17", optional = true } @@ -38,15 +38,11 @@ serde = { version = "1.0.130", optional = true, features = ["derive"] } siphasher = { version = "1", optional = true } smallvec = { version = "1.3", optional = true } tokio = { version = "1.33", optional = true, features = ["io-util", "macros", "net", "time", "sync", "rt-multi-thread" ] } -tokio-rustls = { version = "0.24", optional = true, features = [] } +tokio-rustls = { version = "0.26", optional = true, default-features = false } tracing = { version = "0.1.40", optional = true } # For testing in integration tests: -mock_instant = { version = "0.3.2", optional = true, features = ["sync"] } - -[target.'cfg(macos)'.dependencies] -# specifying this overrides minimum-version mio's 0.2.69 libc dependency, which allows the build to work -libc = { version = "0.2.153", default-features = false, optional = true } +#mock_instant = { version = "0.4.0", optional = true } [features] default = ["std", "rand"] @@ -58,14 +54,14 @@ serde = ["dep:serde", "octseq/serde"] sign = ["std"] smallvec = ["dep:smallvec", "octseq/smallvec"] std = ["bytes?/std", "octseq/std", "time/std"] -net = ["bytes", "futures-util", "rand", "std", "tokio", "tokio-rustls"] +net = ["bytes", "futures-util", "rand", "std", "tokio"] tsig = ["bytes", "ring", "smallvec"] validate = ["std", "ring"] zonefile = ["bytes", "serde", "std"] # Unstable features -unstable-client-transport = [ "moka", "tracing" ] -unstable-server-transport = ["arc-swap", "chrono/clock", "hex", "libc", "tracing"] +unstable-client-transport = [ "moka", "net", "tracing" ] +unstable-server-transport = ["arc-swap", "chrono/clock", "hex", "libc", "net", "tracing"] unstable-zonetree = ["futures", "parking_lot", "serde", "tokio", "tracing"] # Test features @@ -77,22 +73,19 @@ unstable-zonetree = ["futures", "parking_lot", "serde", "tokio", "tracing"] #mock-time = ["mock_instant"] [dev-dependencies] -rstest = "0.18.2" -rustls = { version = "0.21.9" } -serde_test = "1.0.130" -serde_json = "1.0.113" -serde_yaml = "0.9" -tokio = { version = "1.37", features = ["rt-multi-thread", "io-util", "net", "test-util"] } -tokio-test = "0.4" -webpki-roots = { version = "0.25" } - -rustls-pemfile = { version = "1.0" } -socket2 = { version = "0.5.5" } -tokio-tfo = { version = "0.2.0" } - -# For tracing support in integration tests: -lazy_static = { version = "1.4.0" } # Force lazy_static to > 1.0.0 for https://github.com/rust-lang-nursery/lazy-static.rs/pull/107 +lazy_static = { version = "1.4.0" } +rstest = "0.19.0" +rustls-pemfile = { version = "2.1.2" } +serde_test = "1.0.130" +serde_json = "1.0.113" +serde_yaml = "0.9" +socket2 = { version = "0.5.5" } +tokio = { version = "1.37", features = ["rt-multi-thread", "io-util", "net", "test-util"] } +tokio-rustls = { version = "0.26", default-features = false, features = [ "ring", "logging", "tls12" ] } +tokio-test = "0.4" +tokio-tfo = { version = "0.2.0" } tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } +webpki-roots = { version = "0.26" } # For the "mysql-zone" example #sqlx = { version = "0.6", features = [ "runtime-tokio-native-tls", "mysql" ] } diff --git a/Changelog.md b/Changelog.md index e7d636c16..7fdb35087 100644 --- a/Changelog.md +++ b/Changelog.md @@ -4,14 +4,31 @@ Breaking changes +New + +Bug fixes + +Other changes + + +## 0.10.0 + +Released 2024-04-30. + +Breaking changes + +* All types and functions referring to domain names have been changed from + using the term “dname” to just “name.” For instance, `Dname` has become + `Name`, `ToDname` has become `ToName`, and `ToDname::to_dname` has become + `ToName::to_name`. ([#290]) +* The `ToName` and `ToRelativeName` traits have been changed to have a + pair of methods a la `try_to_name` and `to_name` for octets builders + with limited and unlimited buffers, reflecting the pattern used + elsewhere. ([#285]) * The types for IANA-registered parameters in `base::iana` have been changed from enums to a newtypes around their underlying integer type and associated constants for the registered values. (This was really - always the better way to structure this.) ([#276]) -* The `ToDname` and `ToRelativeDname` traits have been changed to have a - pair of methods a la `try_to_dname` and `to_dname` for octets builders - with limited and unlimited buffers, reflecting the pattern used - elsewhere. ([#285]) + always the better way to structure this.) ([#276], [#298]) * The `Txt` record data type now rejects empty record data as invalid. As a consequence `TxtBuilder` converts an empty builder into TXT record data consisting of one empty character string which requires @@ -31,6 +48,10 @@ Breaking changes * The stub resolver now uses the new client transports. This doesn’t change how it is used but does change how it queries the configured servers. ([#215]) +* The sub resolver’s server configuration `Transport` type has been + changed to be either `Transport::UdpTcp` for trying UDP and if that + leads to a truncated answer try TCP and `Transport::Tcp` for only trying + TCP. The stub resolver uses these accordingly now ([#296]) * Many error types have been changed from enums to structs that hide internal error details. Enums have been kept for errors where distinguishing variants might be meaningful for dealing with the error. @@ -39,6 +60,7 @@ Breaking changes * Split RRSIG timestamp handling from `Serial` into a new type `rdata::dnssec::Timestamp`. ([#294]) * Upgraded `octseq` to 0.5. ([#257]) +* The minimum Rust version is now 1.70. ([#304]) New @@ -52,6 +74,11 @@ New the specific options types that didn’t have them yet. ([#257]) * Added missing ordering impls to `ZoneRecordData`, `AllRecordData`, `Opt`, and `SvcbRdata`. ([#293]) +* Added `Name::reverse_from_addr` that creates a domain name for the + reverse lookup of an IP address. ([#289]) +* Added `OptBuilder::clone_from` to replace the OPT record with the + content of another OPT record. ([#299]) +* Added `Message::for_slice_ref` that returns a `Message<&[u8]>`. ([#300]) Bug fixes @@ -66,22 +93,19 @@ Bug fixes Unstable features -* Add the module `net::client` with experimental support for client +* Added the module `net::client` with experimental support for client message transport, i.e., sending of requests and receiving responses as well as caching of responses. This is gated by the `unstable-client-transport` feature. ([#215],[#275]) -* Add the module `net::server` with experimental support for server +* Added the module `net::server` with experimental support for server transports, processing requests through a middleware chain and a service trait. This is gated by the `unstable-server-transport` feature. ([#274]) -* Add the module `zonetree` providing basic traits representing a +* Added the module `zonetree` providing basic traits representing a collection of zones and their data. The `zonetree::in_memory` module - provides an in-memory implementation. The `zonefile::parsed` module + provides an in-memory implementation. The `zonetree::parsed` module provides a way to classify RRsets before inserting them into a tree. This is gated by the `unstable-zonetree` feature. ([#286]) - - -Other changes [#215]: https://github.com/NLnetLabs/domain/pull/215 @@ -102,8 +126,15 @@ Other changes [#285]: https://github.com/NLnetLabs/domain/pull/285 [#286]: https://github.com/NLnetLabs/domain/pull/286 [#288]: https://github.com/NLnetLabs/domain/pull/288 +[#289]: https://github.com/NLnetLabs/domain/pull/289 +[#290]: https://github.com/NLnetLabs/domain/pull/290 [#292]: https://github.com/NLnetLabs/domain/pull/292 [#293]: https://github.com/NLnetLabs/domain/pull/293 +[#296]: https://github.com/NLnetLabs/domain/pull/296 +[#298]: https://github.com/NLnetLabs/domain/pull/298 +[#299]: https://github.com/NLnetLabs/domain/pull/299 +[#300]: https://github.com/NLnetLabs/domain/pull/300 +[#304]: https://github.com/NLnetLabs/domain/pull/304 [@torin-carey]: https://github.com/torin-carey [@hunts]: https://github.com/hunts diff --git a/examples/client-transports.rs b/examples/client-transports.rs index 682f033c3..7166d0a08 100644 --- a/examples/client-transports.rs +++ b/examples/client-transports.rs @@ -1,6 +1,6 @@ -/// Using the `domain::net::client` module for sending a query. -use domain::base::Dname; use domain::base::MessageBuilder; +/// Using the `domain::net::client` module for sending a query. +use domain::base::Name; use domain::base::Rtype; use domain::net::client::cache; use domain::net::client::dgram; @@ -15,7 +15,7 @@ use std::str::FromStr; use std::time::Duration; use tokio::net::TcpStream; use tokio::time::timeout; -use tokio_rustls::rustls::{ClientConfig, OwnedTrustAnchor, RootCertStore}; +use tokio_rustls::rustls::{ClientConfig, RootCertStore}; #[tokio::main] async fn main() { @@ -30,7 +30,7 @@ async fn main() { let mut msg = MessageBuilder::new_vec(); msg.header_mut().set_rd(true); let mut msg = msg.question(); - msg.push((Dname::vec_from_str("example.com").unwrap(), Rtype::AAAA)) + msg.push((Name::vec_from_str("example.com").unwrap(), Rtype::AAAA)) .unwrap(); let req = RequestMessage::new(msg); @@ -135,20 +135,12 @@ async fn main() { drop(request); // Some TLS boiler plate for the root certificates. - let mut root_store = RootCertStore::empty(); - root_store.add_trust_anchors(webpki_roots::TLS_SERVER_ROOTS.iter().map( - |ta| { - OwnedTrustAnchor::from_subject_spki_name_constraints( - ta.subject, - ta.spki, - ta.name_constraints, - ) - }, - )); + let root_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; // TLS config let client_config = ClientConfig::builder() - .with_safe_defaults() .with_root_certificates(root_store) .with_no_client_auth(); diff --git a/examples/client.rs b/examples/client.rs index 1b3cb90cf..fc59286da 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -3,7 +3,7 @@ use std::str::FromStr; use domain::base::opt::AllOptData; use domain::base::{ - Dname, Message, MessageBuilder, Rtype, StaticCompressor, StreamTarget, + Message, MessageBuilder, Name, Rtype, StaticCompressor, StreamTarget, }; use domain::rdata::AllRecordData; @@ -22,11 +22,8 @@ fn create_message() -> StreamTarget> { let mut msg = msg.question(); // Add a hard-coded question and proceed to the answer section. - msg.push(( - Dname::>::from_str("example.com.").unwrap(), - Rtype::A, - )) - .unwrap(); + msg.push((Name::>::from_str("example.com.").unwrap(), Rtype::A)) + .unwrap(); // Skip to the additional section let mut msg = msg.additional(); diff --git a/examples/common/serve-utils.rs b/examples/common/serve-utils.rs index c1becae82..0978481c2 100644 --- a/examples/common/serve-utils.rs +++ b/examples/common/serve-utils.rs @@ -1,10 +1,10 @@ use bytes::Bytes; -use domain::base::{Dname, Message, MessageBuilder, ParsedDname, Rtype}; +use domain::base::{Message, MessageBuilder, Name, ParsedName, Rtype}; use domain::rdata::ZoneRecordData; use domain::zonetree::Answer; pub fn generate_wire_query( - qname: &Dname, + qname: &Name, qtype: Rtype, ) -> Message> { let query = MessageBuilder::new_vec(); @@ -99,7 +99,7 @@ pub fn print_dig_style_response( for record in section { let record = record .unwrap() - .into_record::>>() + .into_record::>>() .unwrap() .unwrap(); diff --git a/examples/download-rust-lang.rs b/examples/download-rust-lang.rs index 06d126b25..8f789d3d4 100644 --- a/examples/download-rust-lang.rs +++ b/examples/download-rust-lang.rs @@ -1,6 +1,6 @@ use std::str::FromStr; -use domain::base::name::Dname; +use domain::base::name::Name; use domain::resolv::StubResolver; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; @@ -9,9 +9,7 @@ use tokio::net::TcpStream; async fn main() { let resolver = StubResolver::new(); let addr = match resolver - .lookup_host( - &Dname::>::from_str("www.rust-lang.org").unwrap(), - ) + .lookup_host(&Name::>::from_str("www.rust-lang.org").unwrap()) .await { Ok(addr) => addr, diff --git a/examples/lookup.rs b/examples/lookup.rs index 898d6fb23..ff7709e6a 100644 --- a/examples/lookup.rs +++ b/examples/lookup.rs @@ -1,21 +1,17 @@ -use domain::base::name::UncertainDname; +use domain::base::name::UncertainName; use domain::resolv::StubResolver; use std::env; use std::net::IpAddr; use std::str::FromStr; -async fn forward(resolver: &StubResolver, name: UncertainDname>) { +async fn forward(resolver: &StubResolver, name: UncertainName>) { let answer = match name { - UncertainDname::Absolute(ref name) => { - resolver.lookup_host(name).await - } - UncertainDname::Relative(ref name) => { - resolver.search_host(name).await - } + UncertainName::Absolute(ref name) => resolver.lookup_host(name).await, + UncertainName::Relative(ref name) => resolver.search_host(name).await, }; match answer { Ok(answer) => { - if let UncertainDname::Relative(_) = name { + if let UncertainName::Relative(_) = name { println!("Found answer for {}", answer.qname()); } let canon = answer.canonical_name(); @@ -55,7 +51,7 @@ async fn main() { for name in names { if let Ok(addr) = IpAddr::from_str(&name) { reverse(&resolver, addr).await; - } else if let Ok(name) = UncertainDname::from_str(&name) { + } else if let Ok(name) = UncertainName::from_str(&name) { forward(&resolver, name).await; } else { println!("Not a domain name: {name}"); diff --git a/examples/query-zone.rs b/examples/query-zone.rs index 2313f5d92..e2b044448 100644 --- a/examples/query-zone.rs +++ b/examples/query-zone.rs @@ -8,7 +8,7 @@ use std::{process::exit, str::FromStr}; use bytes::Bytes; use domain::base::iana::{Class, Rcode}; use domain::base::record::ComposeRecord; -use domain::base::{Dname, ParsedDname, Rtype}; +use domain::base::{Name, ParsedName, Rtype}; use domain::base::{ParsedRecord, Record}; use domain::rdata::ZoneRecordData; use domain::zonefile::inplace; @@ -130,7 +130,7 @@ fn main() { #[allow(clippy::type_complexity)] fn process_dig_style_args( args: env::Args, -) -> Result<(Verbosity, Vec<(String, File)>, Rtype, Dname, bool), String> +) -> Result<(Verbosity, Vec<(String, File)>, Rtype, Name, bool), String> { let mut abort_with_usage = false; let mut verbosity = Verbosity::Normal; @@ -178,7 +178,7 @@ fn process_dig_style_args( .map_err(|err| format!("Cannot parse qtype: {err}"))?; i += 1; - let qname = Dname::::from_str(&args[i]) + let qname = Name::::from_str(&args[i]) .map_err(|err| format!("Cannot parse qname: {err}"))?; Ok((verbosity, zone_files, qtype, qname, short)) @@ -187,7 +187,7 @@ fn process_dig_style_args( } } -fn dump_rrset(owner: Dname, rrset: &Rrset) { +fn dump_rrset(owner: Name, rrset: &Rrset) { // // The following code renders an owner + rrset (IN class, TTL, RDATA) // into zone presentation format. This can be used for diagnostic @@ -200,7 +200,7 @@ fn dump_rrset(owner: Dname, rrset: &Rrset) { let mut parser = Parser::from_ref(&target); if let Ok(parsed_record) = ParsedRecord::parse(&mut parser) { if let Ok(Some(record)) = parsed_record - .into_record::>>() + .into_record::>>() { println!("> {record}"); } diff --git a/examples/resolv-sync.rs b/examples/resolv-sync.rs index eed325d3b..854aae491 100644 --- a/examples/resolv-sync.rs +++ b/examples/resolv-sync.rs @@ -1,4 +1,4 @@ -use domain::base::name::Dname; +use domain::base::name::Name; use domain::base::Rtype; use domain::rdata::AllRecordData; use domain::resolv::StubResolver; @@ -9,7 +9,7 @@ fn main() { let mut args = env::args().skip(1); let name = args .next() - .and_then(|arg| Dname::>::from_str(&arg).ok()); + .and_then(|arg| Name::>::from_str(&arg).ok()); let rtype = args.next().and_then(|arg| Rtype::from_str(&arg).ok()); let (name, rtype) = match (name, rtype) { (Some(name), Some(rtype)) => (name, rtype), diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index 849cef8e8..f5478eeff 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -24,7 +24,7 @@ use tokio::net::{TcpListener, UdpSocket}; use tracing_subscriber::EnvFilter; use domain::base::iana::Rcode; -use domain::base::ToDname; +use domain::base::ToName; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; diff --git a/examples/server-transports.rs b/examples/server-transports.rs index aec0c7e0d..fc8e5de68 100644 --- a/examples/server-transports.rs +++ b/examples/server-transports.rs @@ -8,7 +8,6 @@ use std::fs::File; use std::io; use std::io::BufReader; use std::net::SocketAddr; -use std::path::Path; use std::pin::Pin; use std::sync::Arc; use std::sync::RwLock; @@ -16,10 +15,9 @@ use std::sync::RwLock; use futures::channel::mpsc::unbounded; use futures::stream::{once, Empty, Once, Stream}; use octseq::{FreezeBuilder, Octets}; -use rustls_pemfile::{certs, rsa_private_keys}; use tokio::net::{TcpListener, TcpSocket, TcpStream, UdpSocket}; use tokio::time::Instant; -use tokio_rustls::rustls::{Certificate, PrivateKey}; +use tokio_rustls::rustls; use tokio_rustls::TlsAcceptor; use tokio_tfo::{TfoListener, TfoStream}; use tracing_subscriber::EnvFilter; @@ -28,7 +26,7 @@ use domain::base::iana::{Class, Rcode}; use domain::base::message_builder::{AdditionalBuilder, PushError}; use domain::base::name::ToLabelIter; use domain::base::wire::Composer; -use domain::base::{Dname, MessageBuilder, Rtype, Serial, StreamTarget, Ttl}; +use domain::base::{MessageBuilder, Name, Rtype, Serial, StreamTarget, Ttl}; use domain::net::server::buf::VecBufSource; use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; @@ -62,7 +60,7 @@ where let mut answer = builder.start_answer(msg.message(), Rcode::NOERROR).unwrap(); answer.push(( - Dname::root_ref(), + Name::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1), @@ -78,12 +76,12 @@ where Target: Octets + Composer + FreezeBuilder, ::AppendError: fmt::Debug, { - let mname: Dname> = "a.root-servers.net".parse().unwrap(); + let mname: Name> = "a.root-servers.net".parse().unwrap(); let rname = "nstld.verisign-grs.com".parse().unwrap(); let mut answer = builder.start_answer(msg.message(), Rcode::NOERROR).unwrap(); answer.push(( - Dname::root_slice(), + Name::root_slice(), 86390, Soa::new( mname, @@ -229,7 +227,7 @@ fn name_to_ip(request: Request>) -> ServiceResult> { .start_answer(request.message(), Rcode::NOERROR) .unwrap(); answer - .push((Dname::root_ref(), Class::IN, 86400, a_rec)) + .push((Name::root_ref(), Class::IN, 86400, a_rec)) .unwrap(); out_answer = Some(answer); } @@ -833,31 +831,22 @@ async fn main() { // ----------------------------------------------------------------------- // Demonstrate using a TLS secured TCP DNS server. - fn load_certs(path: &Path) -> io::Result> { - certs(&mut BufReader::new(File::open(path)?)) - .map_err(|_| { - io::Error::new(io::ErrorKind::InvalidInput, "invalid cert") - }) - .map(|mut certs| certs.drain(..).map(Certificate).collect()) - } - - fn load_keys(path: &Path) -> io::Result> { - rsa_private_keys(&mut BufReader::new(File::open(path)?)) - .map_err(|_| { - io::Error::new(io::ErrorKind::InvalidInput, "invalid key") - }) - .map(|mut keys| keys.drain(..).map(PrivateKey).collect()) - } - // Credit: The sample.(pem|rsa) files used here were taken from // https://github.com/rustls/hyper-rustls/blob/main/examples/ - let certs = load_certs(Path::new("examples/sample.pem")).unwrap(); - let mut keys = load_keys(Path::new("examples/sample.rsa")).unwrap(); + let certs = rustls_pemfile::certs(&mut BufReader::new( + File::open("examples/sample.pem").unwrap(), + )) + .collect::, _>>() + .unwrap(); + let key = rustls_pemfile::private_key(&mut BufReader::new( + File::open("examples/sample.rsa").unwrap(), + )) + .unwrap() + .unwrap(); let config = rustls::ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() - .with_single_cert(certs, keys.remove(0)) + .with_single_cert(certs, key) .unwrap(); let acceptor = TlsAcceptor::from(Arc::new(config)); let listener = TcpListener::bind("127.0.0.1:8443").await.unwrap(); diff --git a/src/base/cmp.rs b/src/base/cmp.rs index 025f87727..b2cebc5e0 100644 --- a/src/base/cmp.rs +++ b/src/base/cmp.rs @@ -32,8 +32,8 @@ use core::cmp::Ordering; /// from right to left (i.e, starting from the root label) with each pair of /// labels compared as octet sequences with ASCII letters lowercased /// before comparison. The `name_cmp` methods of the -/// [`ToDname`][crate::base::name::ToDname::name_cmp] and -/// [`ToRelativeDname`][crate::base::name::ToRelativeDname::name_cmp] +/// [`ToDname`][crate::base::name::ToName::name_cmp] and +/// [`ToRelativeDname`][crate::base::name::ToRelativeName::name_cmp] /// traits can be used to implement this canonical order for name types. /// /// Resource records within an RR set are ordered by comparing the canonical diff --git a/src/base/iana/class.rs b/src/base/iana/class.rs index 6e7a80894..51caef05e 100644 --- a/src/base/iana/class.rs +++ b/src/base/iana/class.rs @@ -62,6 +62,7 @@ int_enum_str_with_prefix!(Class, "CLASS", b"CLASS", u16, "unknown class"); #[cfg(test)] mod test { + #[cfg(feature = "serde")] #[test] fn ser_de() { @@ -73,4 +74,13 @@ mod test { assert_tokens(&Class::IN.compact(), &[Token::U16(1)]); assert_tokens(&Class(5).compact(), &[Token::U16(5)]); } + + #[cfg(feature = "std")] + #[test] + fn debug() { + use super::Class; + + assert_eq!(format!("{:?}", Class::IN), "Class::IN"); + assert_eq!(format!("{:?}", Class(69)), "Class(69)"); + } } diff --git a/src/base/iana/macros.rs b/src/base/iana/macros.rs index 9a1b1fb38..93c1c2b2a 100644 --- a/src/base/iana/macros.rs +++ b/src/base/iana/macros.rs @@ -12,7 +12,7 @@ macro_rules! int_enum { $( $(#[$variant_attr:meta])* ( $variant:ident => $value:expr, $mnemonic:expr) )* ) => { $(#[$attr])* - #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] + #[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct $ianatype($inttype); impl $ianatype { @@ -101,6 +101,29 @@ macro_rules! int_enum { value.to_int() } } + + //--- Debug + + impl core::fmt::Debug for $ianatype { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self.to_mnemonic().and_then(|bytes| { + core::str::from_utf8(bytes).ok() + }) { + Some(mnemonic) => { + write!( + f, + concat!(stringify!($ianatype), "::{}"), + mnemonic + ) + } + None => { + f.debug_tuple(stringify!($ianatype)) + .field(&self.0) + .finish() + } + } + } + } } } diff --git a/src/base/iana/rcode.rs b/src/base/iana/rcode.rs index 96a83c530..e9d1525b0 100644 --- a/src/base/iana/rcode.rs +++ b/src/base/iana/rcode.rs @@ -45,7 +45,7 @@ use core::fmt; /// [IANA DNS RCODEs]: http://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-6 /// [RFC 1035]: https://tools.ietf.org/html/rfc1035 /// [RFC 2671]: https://tools.ietf.org/html/rfc2671 -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct Rcode(u8); impl Rcode { @@ -193,6 +193,25 @@ impl Rcode { pub const fn to_int(self) -> u8 { self.0 } + + /// Returns the mnemonic for this value if there is one. + #[must_use] + pub const fn to_mnemonic(self) -> Option<&'static [u8]> { + match self { + Rcode::NOERROR => Some(b"NOERROR"), + Rcode::FORMERR => Some(b"FORMERR"), + Rcode::SERVFAIL => Some(b"SERVFAIL"), + Rcode::NXDOMAIN => Some(b"NXDOMAIN"), + Rcode::NOTIMP => Some(b"NOTIMP"), + Rcode::REFUSED => Some(b"REFUSED"), + Rcode::YXDOMAIN => Some(b"YXDOMAIN"), + Rcode::YXRRSET => Some(b"YXRRSET"), + Rcode::NXRRSET => Some(b"NXRRSET"), + Rcode::NOTAUTH => Some(b"NOAUTH"), + Rcode::NOTZONE => Some(b"NOTZONE"), + _ => None, + } + } } //--- TryFrom and From @@ -211,23 +230,28 @@ impl From for u8 { } } -//--- Display +//--- Display and Debug impl fmt::Display for Rcode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - Rcode::NOERROR => "NOERROR".fmt(f), - Rcode::FORMERR => "FORMERR".fmt(f), - Rcode::SERVFAIL => "SERVFAIL".fmt(f), - Rcode::NXDOMAIN => "NXDOMAIN".fmt(f), - Rcode::NOTIMP => "NOTIMP".fmt(f), - Rcode::REFUSED => "REFUSED".fmt(f), - Rcode::YXDOMAIN => "YXDOMAIN".fmt(f), - Rcode::YXRRSET => "YXRRSET".fmt(f), - Rcode::NXRRSET => "NXRRSET".fmt(f), - Rcode::NOTAUTH => "NOAUTH".fmt(f), - Rcode::NOTZONE => "NOTZONE".fmt(f), - _ => self.0.fmt(f), + match self + .to_mnemonic() + .and_then(|bytes| core::str::from_utf8(bytes).ok()) + { + Some(mnemonic) => f.write_str(mnemonic), + None => self.0.fmt(f), + } + } +} + +impl fmt::Debug for Rcode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self + .to_mnemonic() + .and_then(|bytes| core::str::from_utf8(bytes).ok()) + { + Some(mnemonic) => write!(f, "Rcode::{}", mnemonic), + None => f.debug_tuple("Rcode").field(&self.0).finish(), } } } @@ -314,7 +338,7 @@ impl<'de> serde::Deserialize<'de> for Rcode { /// [RFC 2845]: https://tools.ietf.org/html/rfc2845 /// [RFC 2930]: https://tools.ietf.org/html/rfc2930 /// [RFC 6891]: https://tools.ietf.org/html/rfc6891 -#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[derive(Clone, Copy, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct OptRcode(u16); impl OptRcode { @@ -514,6 +538,27 @@ impl OptRcode { pub fn ext(self) -> u8 { self.to_parts().1 } + + /// Returns the mnemonic for this value if there is one. + #[must_use] + pub const fn to_mnemonic(self) -> Option<&'static [u8]> { + match self { + OptRcode::NOERROR => Some(b"NOERROR"), + OptRcode::FORMERR => Some(b"FORMERR"), + OptRcode::SERVFAIL => Some(b"SERVFAIL"), + OptRcode::NXDOMAIN => Some(b"NXDOMAIN"), + OptRcode::NOTIMP => Some(b"NOTIMP"), + OptRcode::REFUSED => Some(b"REFUSED"), + OptRcode::YXDOMAIN => Some(b"YXDOMAIN"), + OptRcode::YXRRSET => Some(b"YXRRSET"), + OptRcode::NXRRSET => Some(b"NXRRSET"), + OptRcode::NOTAUTH => Some(b"NOAUTH"), + OptRcode::NOTZONE => Some(b"NOTZONE"), + OptRcode::BADVERS => Some(b"BADVERS"), + OptRcode::BADCOOKIE => Some(b"BADCOOKIE"), + _ => None, + } + } } //--- TryFrom and From @@ -538,25 +583,28 @@ impl From for OptRcode { } } -//--- Display +//--- Display and Debug impl fmt::Display for OptRcode { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match *self { - OptRcode::NOERROR => "NOERROR".fmt(f), - OptRcode::FORMERR => "FORMERR".fmt(f), - OptRcode::SERVFAIL => "SERVFAIL".fmt(f), - OptRcode::NXDOMAIN => "NXDOMAIN".fmt(f), - OptRcode::NOTIMP => "NOTIMP".fmt(f), - OptRcode::REFUSED => "REFUSED".fmt(f), - OptRcode::YXDOMAIN => "YXDOMAIN".fmt(f), - OptRcode::YXRRSET => "YXRRSET".fmt(f), - OptRcode::NXRRSET => "NXRRSET".fmt(f), - OptRcode::NOTAUTH => "NOAUTH".fmt(f), - OptRcode::NOTZONE => "NOTZONE".fmt(f), - OptRcode::BADVERS => "BADVER".fmt(f), - OptRcode::BADCOOKIE => "BADCOOKIE".fmt(f), - _ => self.0.fmt(f), + match self + .to_mnemonic() + .and_then(|bytes| core::str::from_utf8(bytes).ok()) + { + Some(mnemonic) => f.write_str(mnemonic), + None => self.0.fmt(f), + } + } +} + +impl fmt::Debug for OptRcode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self + .to_mnemonic() + .and_then(|bytes| core::str::from_utf8(bytes).ok()) + { + Some(mnemonic) => write!(f, "Rcode::{}", mnemonic), + None => f.debug_tuple("Rcode").field(&self.0).finish(), } } } diff --git a/src/base/message.rs b/src/base/message.rs index d0e13bdc9..d14f128dc 100644 --- a/src/base/message.rs +++ b/src/base/message.rs @@ -13,7 +13,7 @@ use super::header::{Header, HeaderCounts, HeaderSection}; use super::iana::{Class, OptRcode, Rcode, Rtype}; use super::message_builder::{AdditionalBuilder, AnswerBuilder, PushError}; -use super::name::ParsedDname; +use super::name::ParsedName; use super::opt::{Opt, OptRecord}; use super::question::Question; use super::rdata::{ParseAnyRecordData, ParseRecordData}; @@ -268,6 +268,14 @@ impl Message { { unsafe { Message::from_slice_unchecked(self.octets.as_ref()) } } + + /// Returns a message for a slice reference. + pub fn for_slice_ref(&self) -> Message<&[u8]> + where + Octs: AsRef<[u8]>, + { + unsafe { Message::from_octets_unchecked(self.octets.as_ref()) } + } } /// # Header Section @@ -432,7 +440,7 @@ impl Message { /// parsing fails. pub fn first_question( &self, - ) -> Option>>> { + ) -> Option>>> { match self.question().next() { None | Some(Err(..)) => None, Some(Ok(question)) => Some(question), @@ -447,7 +455,7 @@ impl Message { /// [`first_question`]: #method.first_question pub fn sole_question( &self, - ) -> Result>>, ParseError> { + ) -> Result>>, ParseError> { match self.header_counts().qdcount() { 0 => return Err(ParseError::form_error("no question")), 1 => {} @@ -494,7 +502,7 @@ impl Message { // must have a loop. While the ANCOUNT could be unreasonably large, the // iterator would break off in this case and we break out with a None // right away. - pub fn canonical_name(&self) -> Option>> { + pub fn canonical_name(&self) -> Option>> { let question = match self.first_question() { None => return None, Some(question) => question, @@ -547,7 +555,7 @@ impl Message { /// `None`. pub fn get_last_additional<'s, Data: ParseRecordData<'s, Octs>>( &'s self, - ) -> Option>, Data>> { + ) -> Option>, Data>> { let mut section = match self.additional() { Ok(section) => section, Err(_) => return None, @@ -794,7 +802,7 @@ impl<'a, Octs: ?Sized> Copy for QuestionSection<'a, Octs> {} //--- Iterator impl<'a, Octs: Octets + ?Sized> Iterator for QuestionSection<'a, Octs> { - type Item = Result>>, ParseError>; + type Item = Result>>, ParseError>; fn next(&mut self) -> Option { match self.count { @@ -946,7 +954,7 @@ impl<'a, Octs: Octets + ?Sized> RecordSection<'a, Octs> { /// The record type is given through its record data type. Since the data /// is being parsed, this type must implement [`ParseRecordData`]. For /// record data types that are generic over domain name types, this is - /// normally achieved by giving them a [`ParsedDname`]. As a convenience, + /// normally achieved by giving them a [`ParsedName`]. As a convenience, /// type aliases for all the fundamental record data types exist in the /// [domain::rdata::parsed] module. /// @@ -954,7 +962,7 @@ impl<'a, Octs: Octets + ?Sized> RecordSection<'a, Octs> { /// of `self`. It will *not* start from the beginning of the section. /// /// [`ParseRecordData`]: ../rdata/trait.ParseRecordData.html - /// [`ParsedDname`]: ../name/struct.ParsedDname.html + /// [`ParsedName`]: ../name/struct.ParsedName.html /// [domain::rdata::parsed]: ../../rdata/parsed/index.html #[must_use] pub fn limit_to>( @@ -1163,8 +1171,7 @@ where Octs: Octets + ?Sized, Data: ParseRecordData<'a, Octs>, { - type Item = - Result>, Data>, ParseError>; + type Item = Result>, Data>, ParseError>; fn next(&mut self) -> Option { loop { @@ -1252,8 +1259,7 @@ where Octs: Octets + ?Sized, Data: ParseAnyRecordData<'a, Octs>, { - type Item = - Result>, Data>, ParseError>; + type Item = Result>, Data>, ParseError>; fn next(&mut self) -> Option { let record = match self.section.next() { @@ -1329,7 +1335,7 @@ mod test { #[cfg(feature = "std")] use crate::base::message_builder::MessageBuilder; #[cfg(feature = "std")] - use crate::base::name::Dname; + use crate::base::name::Name; #[cfg(feature = "std")] use crate::rdata::{AllRecordData, Ns}; #[cfg(feature = "std")] @@ -1341,16 +1347,16 @@ mod test { let msg = MessageBuilder::new_vec(); let mut msg = msg.answer(); msg.push(( - Dname::vec_from_str("foo.example.com.").unwrap(), + Name::vec_from_str("foo.example.com.").unwrap(), 86000, - Cname::new(Dname::vec_from_str("baz.example.com.").unwrap()), + Cname::new(Name::vec_from_str("baz.example.com.").unwrap()), )) .unwrap(); let mut msg = msg.authority(); msg.push(( - Dname::vec_from_str("bar.example.com.").unwrap(), + Name::vec_from_str("bar.example.com.").unwrap(), 86000, - Ns::new(Dname::vec_from_str("baz.example.com.").unwrap()), + Ns::new(Name::vec_from_str("baz.example.com.").unwrap()), )) .unwrap(); msg.into_message() @@ -1369,50 +1375,50 @@ mod test { // Message without CNAMEs. let mut msg = MessageBuilder::new_vec().question(); - msg.push((Dname::vec_from_str("example.com.").unwrap(), Rtype::A)) + msg.push((Name::vec_from_str("example.com.").unwrap(), Rtype::A)) .unwrap(); let msg_ref = msg.as_message(); assert_eq!( - Dname::vec_from_str("example.com.").unwrap(), + Name::vec_from_str("example.com.").unwrap(), msg_ref.canonical_name().unwrap() ); // Message with CNAMEs. let mut msg = msg.answer(); msg.push(( - Dname::vec_from_str("bar.example.com.").unwrap(), + Name::vec_from_str("bar.example.com.").unwrap(), 86000, - Cname::new(Dname::vec_from_str("baz.example.com.").unwrap()), + Cname::new(Name::vec_from_str("baz.example.com.").unwrap()), )) .unwrap(); msg.push(( - Dname::vec_from_str("example.com.").unwrap(), + Name::vec_from_str("example.com.").unwrap(), 86000, - Cname::new(Dname::vec_from_str("foo.example.com.").unwrap()), + Cname::new(Name::vec_from_str("foo.example.com.").unwrap()), )) .unwrap(); msg.push(( - Dname::vec_from_str("foo.example.com.").unwrap(), + Name::vec_from_str("foo.example.com.").unwrap(), 86000, - Cname::new(Dname::vec_from_str("bar.example.com.").unwrap()), + Cname::new(Name::vec_from_str("bar.example.com.").unwrap()), )) .unwrap(); let msg_ref = msg.as_message(); assert_eq!( - Dname::vec_from_str("baz.example.com.").unwrap(), + Name::vec_from_str("baz.example.com.").unwrap(), msg_ref.canonical_name().unwrap() ); // CNAME loop. msg.push(( - Dname::vec_from_str("baz.example.com").unwrap(), + Name::vec_from_str("baz.example.com").unwrap(), 86000, - Cname::new(Dname::vec_from_str("foo.example.com").unwrap()), + Cname::new(Name::vec_from_str("foo.example.com").unwrap()), )) .unwrap(); assert!(msg.as_message().canonical_name().is_none()); msg.push(( - Dname::vec_from_str("baz.example.com").unwrap(), + Name::vec_from_str("baz.example.com").unwrap(), 86000, A::from_octets(127, 0, 0, 1), )) @@ -1442,7 +1448,7 @@ mod test { let target = MessageBuilder::new_vec().question(); let res = msg.copy_records(target.answer(), |rr| { if let Ok(Some(rr)) = - rr.into_record::>>() + rr.into_record::>>() { if rr.rtype() == Rtype::CNAME { return Some(rr); diff --git a/src/base/message_builder.rs b/src/base/message_builder.rs index a6831100f..1edcb3e3f 100644 --- a/src/base/message_builder.rs +++ b/src/base/message_builder.rs @@ -71,12 +71,12 @@ #![cfg_attr(not(feature = "std"), doc = "```ignore")] //! use std::str::FromStr; //! use domain::base::{ -//! Dname, MessageBuilder, Rtype, StaticCompressor, StreamTarget +//! Name, MessageBuilder, Rtype, StaticCompressor, StreamTarget //! }; //! use domain::rdata::A; //! //! // Make a domain name we can use later on. -//! let name = Dname::>::from_str("example.com").unwrap(); +//! let name = Name::>::from_str("example.com").unwrap(); //! //! // Create a message builder wrapping a compressor wrapping a stream //! // target. @@ -134,8 +134,8 @@ use super::header::{CountOverflow, Header, HeaderCounts, HeaderSection}; use super::iana::Rtype; use super::iana::{OptRcode, OptionCode, Rcode}; use super::message::Message; -use super::name::{Label, ToDname}; -use super::opt::{ComposeOptData, OptHeader}; +use super::name::{Label, ToName}; +use super::opt::{ComposeOptData, OptHeader, OptRecord}; use super::question::ComposeQuestion; use super::record::ComposeRecord; use super::wire::{Compose, Composer}; @@ -259,7 +259,7 @@ impl MessageBuilder { /// Sets a random ID, pushes the domain and the AXFR record type into /// the question section, and converts the builder into an answer builder. #[cfg(feature = "rand")] - pub fn request_axfr( + pub fn request_axfr( mut self, apex: N, ) -> Result, PushError> { @@ -520,14 +520,14 @@ impl QuestionBuilder { /// #[cfg_attr(feature = "std", doc = "```")] #[cfg_attr(not(feature = "std"), doc = "```ignore")] - /// use domain::base::{Dname, MessageBuilder, Question, Rtype}; + /// use domain::base::{Name, MessageBuilder, Question, Rtype}; /// use domain::base::iana::Class; /// /// let mut msg = MessageBuilder::new_vec().question(); - /// msg.push(Question::new_in(Dname::root_ref(), Rtype::A)).unwrap(); - /// msg.push(&Question::new_in(Dname::root_ref(), Rtype::A)).unwrap(); - /// msg.push((Dname::root_ref(), Rtype::A, Class::IN)).unwrap(); - /// msg.push((Dname::root_ref(), Rtype::A)).unwrap(); + /// msg.push(Question::new_in(Name::root_ref(), Rtype::A)).unwrap(); + /// msg.push(&Question::new_in(Name::root_ref(), Rtype::A)).unwrap(); + /// msg.push((Name::root_ref(), Rtype::A, Class::IN)).unwrap(); + /// msg.push((Name::root_ref(), Rtype::A)).unwrap(); /// ``` pub fn push( &mut self, @@ -772,21 +772,21 @@ impl AnswerBuilder { /// #[cfg_attr(feature = "std", doc = "```")] #[cfg_attr(not(feature = "std"), doc = "```ignore")] - /// use domain::base::{Dname, MessageBuilder, Record, Rtype, Ttl}; + /// use domain::base::{Name, MessageBuilder, Record, Rtype, Ttl}; /// use domain::base::iana::Class; /// use domain::rdata::A; /// /// let mut msg = MessageBuilder::new_vec().answer(); /// let record = Record::new( - /// Dname::root_ref(), Class::IN, Ttl::from_secs(86400), A::from_octets(192, 0, 2, 1) + /// Name::root_ref(), Class::IN, Ttl::from_secs(86400), A::from_octets(192, 0, 2, 1) /// ); /// msg.push(&record).unwrap(); /// msg.push(record).unwrap(); /// msg.push( - /// (Dname::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1)) + /// (Name::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1)) /// ).unwrap(); /// msg.push( - /// (Dname::root_ref(), 86400, A::from_octets(192, 0, 2, 1)) + /// (Name::root_ref(), 86400, A::from_octets(192, 0, 2, 1)) /// ).unwrap(); /// ``` /// @@ -1024,22 +1024,22 @@ impl AuthorityBuilder { /// #[cfg_attr(feature = "std", doc = "```")] #[cfg_attr(not(feature = "std"), doc = "```ignore")] - /// use domain::base::{Dname, MessageBuilder, Record, Rtype, Ttl}; + /// use domain::base::{Name, MessageBuilder, Record, Rtype, Ttl}; /// use domain::base::iana::Class; /// use domain::rdata::A; /// /// let mut msg = MessageBuilder::new_vec().authority(); /// let record = Record::new( - /// Dname::root_ref(), Class::IN, Ttl::from_secs(86400), + /// Name::root_ref(), Class::IN, Ttl::from_secs(86400), /// A::from_octets(192, 0, 2, 1) /// ); /// msg.push(&record).unwrap(); /// msg.push(record).unwrap(); /// msg.push( - /// (Dname::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1)) + /// (Name::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1)) /// ).unwrap(); /// msg.push( - /// (Dname::root_ref(), 86400, A::from_octets(192, 0, 2, 1)) + /// (Name::root_ref(), 86400, A::from_octets(192, 0, 2, 1)) /// ).unwrap(); /// ``` pub fn push( @@ -1282,21 +1282,21 @@ impl AdditionalBuilder { /// #[cfg_attr(feature = "std", doc = "```")] #[cfg_attr(not(feature = "std"), doc = "```ignore")] - /// use domain::base::{Dname, MessageBuilder, Record, Rtype, Ttl}; + /// use domain::base::{Name, MessageBuilder, Record, Rtype, Ttl}; /// use domain::base::iana::Class; /// use domain::rdata::A; /// /// let mut msg = MessageBuilder::new_vec().additional(); /// let record = Record::new( - /// Dname::root_ref(), Class::IN, Ttl::from_secs(86400), A::from_octets(192, 0, 2, 1) + /// Name::root_ref(), Class::IN, Ttl::from_secs(86400), A::from_octets(192, 0, 2, 1) /// ); /// msg.push(&record).unwrap(); /// msg.push(record).unwrap(); /// msg.push( - /// (Dname::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1)) + /// (Name::root_ref(), Class::IN, 86400, A::from_octets(192, 0, 2, 1)) /// ).unwrap(); /// msg.push( - /// (Dname::root_ref(), 86400, A::from_octets(192, 0, 2, 1)) + /// (Name::root_ref(), 86400, A::from_octets(192, 0, 2, 1)) /// ).unwrap(); /// ``` pub fn push( @@ -1593,6 +1593,16 @@ impl<'a, Target: Composer + ?Sized> OptBuilder<'a, Target> { } } + /// Replaces the contents of this [`OptBuilder`] with the given + /// [`OptRecord`]`. + pub fn clone_from>( + &mut self, + source: &OptRecord, + ) -> Result<(), Target::AppendError> { + self.target.truncate(self.start); + source.as_record().compose(self.target) + } + /// Appends an option to the OPT record. pub fn push( &mut self, @@ -1844,12 +1854,12 @@ where Target: Composer, Target::AppendError: Into, { - fn append_compressed_dname( + fn append_compressed_name( &mut self, name: &N, ) -> Result<(), Self::AppendError> { self.target - .append_compressed_dname(name) + .append_compressed_name(name) .map_err(Into::into)?; self.update_shim() } @@ -1980,7 +1990,7 @@ impl OctetsBuilder for StaticCompressor { } impl Composer for StaticCompressor { - fn append_compressed_dname( + fn append_compressed_name( &mut self, name: &N, ) -> Result<(), Self::AppendError> { @@ -2197,7 +2207,7 @@ impl OctetsBuilder for TreeCompressor { #[cfg(feature = "std")] impl Composer for TreeCompressor { - fn append_compressed_dname( + fn append_compressed_name( &mut self, name: &N, ) -> Result<(), Self::AppendError> { @@ -2284,14 +2294,14 @@ impl std::error::Error for PushError {} mod test { use super::*; use crate::base::opt; - use crate::base::{Dname, Serial, Ttl}; + use crate::base::{Name, Serial, Ttl}; use crate::rdata::{Ns, Soa, A}; use core::str::FromStr; #[test] fn message_builder() { // Make a domain name we can use later on. - let name = Dname::>::from_str("example.com").unwrap(); + let name = Name::>::from_str("example.com").unwrap(); // Create a message builder wrapping a compressor wrapping a stream // target. @@ -2390,14 +2400,14 @@ mod test { msg.header_mut().set_ra(true); msg.header_mut().set_qr(true); - msg.push((&"example".parse::>>().unwrap(), Rtype::NS)) + msg.push((&"example".parse::>>().unwrap(), Rtype::NS)) .unwrap(); let mut msg = msg.authority(); - let mname: Dname> = "a.root-servers.net".parse().unwrap(); + let mname: Name> = "a.root-servers.net".parse().unwrap(); let rname = "nstld.verisign-grs.com".parse().unwrap(); msg.push(( - Dname::root_slice(), + Name::root_slice(), 86390, Soa::new( mname, diff --git a/src/base/mod.rs b/src/base/mod.rs index b8908722c..0b55899fd 100644 --- a/src/base/mod.rs +++ b/src/base/mod.rs @@ -104,7 +104,7 @@ pub use self::message_builder::{ MessageBuilder, RecordSectionBuilder, StaticCompressor, StreamTarget, }; pub use self::name::{ - Dname, DnameBuilder, ParsedDname, RelativeDname, ToDname, ToRelativeDname, + Name, NameBuilder, ParsedName, RelativeName, ToName, ToRelativeName, }; pub use self::question::Question; pub use self::rdata::{ParseRecordData, RecordData, UnknownRecordData}; diff --git a/src/base/name/dname.rs b/src/base/name/absolute.rs similarity index 74% rename from src/base/name/dname.rs rename to src/base/name/absolute.rs index 0505914c4..fa6a264f9 100644 --- a/src/base/name/dname.rs +++ b/src/base/name/absolute.rs @@ -3,12 +3,13 @@ //! This is a private module. Its public types are re-exported by the parent. use super::super::cmp::CanonicalOrd; +use super::super::net::IpAddr; use super::super::scan::{Scanner, Symbol, SymbolCharsError, Symbols}; use super::super::wire::{FormError, ParseError}; -use super::builder::{DnameBuilder, FromStrError}; +use super::builder::{FromStrError, NameBuilder, PushError}; use super::label::{Label, LabelTypeError, SplitLabelError}; -use super::relative::{DnameIter, RelativeDname}; -use super::traits::{FlattenInto, ToDname, ToLabelIter}; +use super::relative::{NameIter, RelativeName}; +use super::traits::{FlattenInto, ToLabelIter, ToName}; #[cfg(feature = "bytes")] use bytes::Bytes; use core::ops::{Bound, RangeBounds}; @@ -24,7 +25,7 @@ use octseq::serde::{DeserializeOctets, SerializeOctets}; #[cfg(feature = "std")] use std::vec::Vec; -//------------ Dname --------------------------------------------------------- +//------------ Name ---------------------------------------------------------- /// An uncompressed, absolute domain name. /// @@ -34,32 +35,32 @@ use std::vec::Vec; /// etc. /// /// You can construct a domain name from a string via the `FromStr` trait or -/// manually via a [`DnameBuilder`]. In addition, you can also parse it from +/// manually via a [`NameBuilder`]. In addition, you can also parse it from /// a message. This will, however, require the name to be uncompressed. /// Otherwise, you would receive a [`ParsedDname`] which can be converted into -/// `Dname` via [`ToDname::to_dname`]. +/// `Name` via [`ToName::to_name`]. /// /// The canonical way to convert a domain name into its presentation format is /// using [`to_string`] or by using its [`Display`] implementation (which /// performs no allocations). /// -/// [`DnameBuilder`]: struct.DnameBuilder.html +/// [`NameBuilder`]: struct.NameBuilder.html /// [`ParsedDname`]: struct.ParsedDname.html -/// [`RelativeDname`]: struct.RelativeDname.html -/// [`ToDname::to_dname`]: trait.ToDname.html#method.to_dname +/// [`RelativeName`]: struct.RelativeName.html +/// [`ToName::to_name`]: trait.ToName.html#method.to_name /// [`to_string`]: `std::string::ToString::to_string` /// [`Display`]: `std::fmt::Display` #[derive(Clone)] -pub struct Dname(Octs); +pub struct Name(Octs); -impl Dname<()> { +impl Name<()> { /// Domain names have a maximum length of 255 octets. pub const MAX_LEN: usize = 255; } /// # Creating Values /// -impl Dname { +impl Name { /// Creates a domain name from the underlying octets without any check. /// /// Since this will allow to actually construct an incorrectly encoded @@ -71,7 +72,7 @@ impl Dname { /// encoded absolute domain name. It must be at most 255 octets long. /// It must contain the root label exactly once as its last label. pub const unsafe fn from_octets_unchecked(octets: Octs) -> Self { - Dname(octets) + Self(octets) } /// Creates a domain name from an octet sequence. @@ -79,12 +80,12 @@ impl Dname { /// This will only succeed if `octets` contains a properly encoded /// absolute domain name in wire format. Because the function checks for /// correctness, this will take a wee bit of time. - pub fn from_octets(octets: Octs) -> Result + pub fn from_octets(octets: Octs) -> Result where Octs: AsRef<[u8]>, { - Dname::check_slice(octets.as_ref())?; - Ok(unsafe { Dname::from_octets_unchecked(octets) }) + Name::check_slice(octets.as_ref())?; + Ok(unsafe { Self::from_octets_unchecked(octets) }) } pub fn from_symbols(symbols: Sym) -> Result @@ -96,7 +97,7 @@ impl Dname { + AsMut<[u8]>, Sym: IntoIterator, { - // DnameBuilder can’t deal with a single dot, so we need to special + // NameBuilder can’t deal with a single dot, so we need to special // case that. let mut symbols = symbols.into_iter(); let first = match symbols.next() { @@ -119,10 +120,10 @@ impl Dname { } } - let mut builder = DnameBuilder::::new(); + let mut builder = NameBuilder::::new(); builder.push_symbol(first)?; builder.append_symbols(symbols)?; - builder.into_dname().map_err(Into::into) + builder.into_name().map_err(Into::into) } /// Creates a domain name from a sequence of characters. @@ -158,10 +159,10 @@ impl Dname { } /// Reads a name in presentation format from the beginning of a scanner. - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { - scanner.scan_dname() + scanner.scan_name() } /// Returns a domain name consisting of the root label only. @@ -181,12 +182,48 @@ impl Dname { { unsafe { Self::from_octets_unchecked(b"\0".as_ref().into()) } } + + /// Creates a domain name for reverse IP address lookup. + /// + /// The returned name will use the standard suffixes of `in-addr.arpa.` + /// for IPv4 addresses and `ip6.arpa.` for IPv6. + pub fn reverse_from_addr(addr: IpAddr) -> Result + where + Octs: FromBuilder, + ::Builder: EmptyBuilder + + FreezeBuilder + + AsRef<[u8]> + + AsMut<[u8]>, + { + let mut builder = + NameBuilder::<::Builder>::new(); + match addr { + IpAddr::V4(addr) => { + let [a, b, c, d] = addr.octets(); + builder.append_dec_u8_label(d)?; + builder.append_dec_u8_label(c)?; + builder.append_dec_u8_label(b)?; + builder.append_dec_u8_label(a)?; + builder.append_label(b"in-addr")?; + builder.append_label(b"arpa")?; + } + IpAddr::V6(addr) => { + for &item in addr.octets().iter().rev() { + builder.append_hex_digit_label(item)?; + builder.append_hex_digit_label(item >> 4)?; + } + builder.append_label(b"ip6")?; + builder.append_label(b"arpa")?; + } + } + builder.into_name() + } } -impl Dname<[u8]> { +impl Name<[u8]> { /// Creates a domain name from an octet slice without checking, unsafe fn from_slice_unchecked(slice: &[u8]) -> &Self { - &*(slice as *const [u8] as *const Dname<[u8]>) + &*(slice as *const [u8] as *const Name<[u8]>) } /// Creates a domain name from an octets slice. @@ -196,15 +233,15 @@ impl Dname<[u8]> { /// # Examples /// /// ``` - /// use domain::base::name::Dname; - /// Dname::from_slice(b"\x07example\x03com\x00"); + /// use domain::base::name::Name; + /// Name::from_slice(b"\x07example\x03com\x00"); /// ``` /// /// # Errors /// /// This will only succeed if `slice` contains a properly encoded /// absolute domain name. - pub fn from_slice(slice: &[u8]) -> Result<&Self, DnameError> { + pub fn from_slice(slice: &[u8]) -> Result<&Self, NameError> { Self::check_slice(slice)?; Ok(unsafe { Self::from_slice_unchecked(slice) }) } @@ -216,9 +253,9 @@ impl Dname<[u8]> { } /// Checks whether an octet slice contains a correctly encoded name. - fn check_slice(mut slice: &[u8]) -> Result<(), DnameError> { - if slice.len() > Dname::MAX_LEN { - return Err(DnameError(DnameErrorEnum::LongName)); + fn check_slice(mut slice: &[u8]) -> Result<(), NameError> { + if slice.len() > Name::MAX_LEN { + return Err(NameError(DnameErrorEnum::LongName)); } loop { let (label, tail) = Label::split_from(slice)?; @@ -226,11 +263,11 @@ impl Dname<[u8]> { if tail.is_empty() { break; } else { - return Err(DnameError(DnameErrorEnum::TrailingData)); + return Err(NameError(DnameErrorEnum::TrailingData)); } } if tail.is_empty() { - return Err(DnameError(DnameErrorEnum::RelativeName)); + return Err(NameError(DnameErrorEnum::RelativeName)); } slice = tail; } @@ -238,7 +275,7 @@ impl Dname<[u8]> { } } -impl Dname<&'static [u8]> { +impl Name<&'static [u8]> { /// Creates a domain name for the root label only atop a slice reference. #[must_use] pub fn root_ref() -> Self { @@ -247,7 +284,7 @@ impl Dname<&'static [u8]> { } #[cfg(feature = "std")] -impl Dname> { +impl Name> { /// Creates a domain name for the root label only atop a `Vec`. #[must_use] pub fn root_vec() -> Self { @@ -261,7 +298,7 @@ impl Dname> { } #[cfg(feature = "bytes")] -impl Dname { +impl Name { /// Creates a domain name for the root label only atop a bytes values. pub fn root_bytes() -> Self { Self::root() @@ -275,7 +312,7 @@ impl Dname { /// # Conversions /// -impl Dname { +impl Name { /// Returns a reference to the underlying octets sequence. /// /// These octets contain the domain name in wire format. @@ -292,18 +329,18 @@ impl Dname { } /// Converts the name into a relative name by dropping the root label. - pub fn into_relative(mut self) -> RelativeDname + pub fn into_relative(mut self) -> RelativeName where Octs: Sized + AsRef<[u8]> + Truncate, { let len = self.0.as_ref().len() - 1; self.0.truncate(len); - unsafe { RelativeDname::from_octets_unchecked(self.0) } + unsafe { RelativeName::from_octets_unchecked(self.0) } } /// Returns a domain name using a reference to the octets. - pub fn for_ref(&self) -> Dname<&Octs> { - unsafe { Dname::from_octets_unchecked(&self.0) } + pub fn for_ref(&self) -> Name<&Octs> { + unsafe { Name::from_octets_unchecked(&self.0) } } /// Returns a reference to the underlying octets slice. @@ -317,11 +354,11 @@ impl Dname { } /// Returns a domain name for the octets slice of the content. - pub fn for_slice(&self) -> &Dname<[u8]> + pub fn for_slice(&self) -> &Name<[u8]> where Octs: AsRef<[u8]>, { - unsafe { Dname::from_slice_unchecked(self.0.as_ref()) } + unsafe { Name::from_slice_unchecked(self.0.as_ref()) } } /// Converts the domain name into its canonical form. @@ -338,7 +375,7 @@ impl Dname { /// # Properties /// -impl + ?Sized> Dname { +impl + ?Sized> Name { /// Returns whether the name is the root label only. pub fn is_root(&self) -> bool { self.0.as_ref().len() == 1 @@ -363,10 +400,10 @@ impl + ?Sized> Dname { /// # Working with Labels /// -impl + ?Sized> Dname { +impl + ?Sized> Name { /// Returns an iterator over the labels of the domain name. - pub fn iter(&self) -> DnameIter { - DnameIter::new(self.0.as_ref()) + pub fn iter(&self) -> NameIter { + NameIter::new(self.0.as_ref()) } /// Returns an iterator over the suffixes of the name. @@ -485,10 +522,10 @@ impl + ?Sized> Dname { pub fn slice( &self, range: impl RangeBounds, - ) -> &RelativeDname<[u8]> { + ) -> &RelativeName<[u8]> { self.check_bounds(&range); unsafe { - RelativeDname::from_slice_unchecked(self.0.as_ref().range(range)) + RelativeName::from_slice_unchecked(self.0.as_ref().range(range)) } } @@ -509,9 +546,9 @@ impl + ?Sized> Dname { /// label or is out of bounds. /// /// [`range_from`]: #method.range_from - pub fn slice_from(&self, begin: usize) -> &Dname<[u8]> { + pub fn slice_from(&self, begin: usize) -> &Name<[u8]> { self.check_index(begin); - unsafe { Dname::from_slice_unchecked(&self.0.as_ref()[begin..]) } + unsafe { Name::from_slice_unchecked(&self.0.as_ref()[begin..]) } } /// Returns the part of the name indicated by start and end positions. @@ -534,12 +571,12 @@ impl + ?Sized> Dname { pub fn range( &self, range: impl RangeBounds, - ) -> RelativeDname<::Range<'_>> + ) -> RelativeName<::Range<'_>> where Octs: Octets, { self.check_bounds(&range); - unsafe { RelativeDname::from_octets_unchecked(self.0.range(range)) } + unsafe { RelativeName::from_octets_unchecked(self.0.range(range)) } } /// Returns the part of the name starting at the given position. @@ -556,7 +593,7 @@ impl + ?Sized> Dname { pub fn range_from( &self, begin: usize, - ) -> Dname<::Range<'_>> + ) -> Name<::Range<'_>> where Octs: Octets, { @@ -568,15 +605,15 @@ impl + ?Sized> Dname { unsafe fn range_from_unchecked( &self, begin: usize, - ) -> Dname<::Range<'_>> + ) -> Name<::Range<'_>> where Octs: Octets, { - Dname::from_octets_unchecked(self.0.range(begin..)) + Name::from_octets_unchecked(self.0.range(begin..)) } } -impl + ?Sized> Dname { +impl + ?Sized> Name { /// Splits the name into two at the given position. /// /// Returns a pair of the left and right part of the split name. @@ -588,15 +625,15 @@ impl + ?Sized> Dname { pub fn split( &self, mid: usize, - ) -> (RelativeDname>, Dname>) + ) -> (RelativeName>, Name>) where Octs: Octets, { self.check_index(mid); unsafe { ( - RelativeDname::from_octets_unchecked(self.0.range(..mid)), - Dname::from_octets_unchecked(self.0.range(mid..)), + RelativeName::from_octets_unchecked(self.0.range(..mid)), + Name::from_octets_unchecked(self.0.range(mid..)), ) } } @@ -610,13 +647,13 @@ impl + ?Sized> Dname { /// /// The method will panic if `len` is not the index of a new label or if /// it is out of bounds. - pub fn truncate(mut self, len: usize) -> RelativeDname + pub fn truncate(mut self, len: usize) -> RelativeName where Octs: Truncate + Sized, { self.check_index(len); self.0.truncate(len); - unsafe { RelativeDname::from_octets_unchecked(self.0) } + unsafe { RelativeName::from_octets_unchecked(self.0) } } /// Splits off the first label. @@ -624,7 +661,7 @@ impl + ?Sized> Dname { /// If this name is longer than just the root label, returns a pair /// of that label and the remaining name. If the name is only the root /// label, returns `None`. - pub fn split_first(&self) -> Option<(&Label, Dname>)> + pub fn split_first(&self) -> Option<(&Label, Name>)> where Octs: Octets, { @@ -638,7 +675,7 @@ impl + ?Sized> Dname { /// Returns the parent of the current name. /// /// If the name consists of the root label only, returns `None`. - pub fn parent(&self) -> Option>> + pub fn parent(&self) -> Option>> where Octs: Octets, { @@ -650,10 +687,10 @@ impl + ?Sized> Dname { /// If `base` is indeed a suffix, returns a relative domain name with the /// remainder of the name. Otherwise, returns an error with an unmodified /// `self`. - pub fn strip_suffix( + pub fn strip_suffix( self, base: &N, - ) -> Result, Self> + ) -> Result, Self> where Octs: Truncate + Sized, { @@ -666,7 +703,7 @@ impl + ?Sized> Dname { } } -impl Dname { +impl Name { /// Reads a name in wire format from the beginning of a parser. pub fn parse<'a, Src: Octets = Octs> + ?Sized>( parser: &mut Parser<'a, Src>, @@ -693,8 +730,8 @@ impl Dname { } parser.remaining() - tmp.len() }; - if len > Dname::MAX_LEN { - Err(DnameError(DnameErrorEnum::LongName).into()) + if len > Name::MAX_LEN { + Err(NameError(DnameErrorEnum::LongName).into()) } else { Ok(len) } @@ -703,13 +740,13 @@ impl Dname { //--- AsRef -impl AsRef for Dname { +impl AsRef for Name { fn as_ref(&self) -> &Octs { &self.0 } } -impl + ?Sized> AsRef<[u8]> for Dname { +impl + ?Sized> AsRef<[u8]> for Name { fn as_ref(&self) -> &[u8] { self.0.as_ref() } @@ -717,13 +754,13 @@ impl + ?Sized> AsRef<[u8]> for Dname { //--- OctetsFrom -impl OctetsFrom> for Dname +impl OctetsFrom> for Name where Octs: OctetsFrom, { type Error = Octs::Error; - fn try_octets_from(source: Dname) -> Result { + fn try_octets_from(source: Name) -> Result { Octs::try_octets_from(source.0) .map(|octets| unsafe { Self::from_octets_unchecked(octets) }) } @@ -731,7 +768,7 @@ where //--- FromStr -impl FromStr for Dname +impl FromStr for Name where Octs: FromBuilder, ::Builder: EmptyBuilder @@ -760,38 +797,38 @@ where //--- FlattenInto -impl FlattenInto> for Dname +impl FlattenInto> for Name where Target: OctetsFrom, { type AppendError = Target::Error; - fn try_flatten_into(self) -> Result, Self::AppendError> { + fn try_flatten_into(self) -> Result, Self::AppendError> { Target::try_octets_from(self.0) - .map(|octets| unsafe { Dname::from_octets_unchecked(octets) }) + .map(|octets| unsafe { Name::from_octets_unchecked(octets) }) } } //--- PartialEq, and Eq -impl PartialEq for Dname +impl PartialEq for Name where Octs: AsRef<[u8]> + ?Sized, - N: ToDname + ?Sized, + N: ToName + ?Sized, { fn eq(&self, other: &N) -> bool { self.name_eq(other) } } -impl + ?Sized> Eq for Dname {} +impl + ?Sized> Eq for Name {} //--- PartialOrd, Ord, and CanonicalOrd -impl PartialOrd for Dname +impl PartialOrd for Name where Octs: AsRef<[u8]> + ?Sized, - N: ToDname + ?Sized, + N: ToName + ?Sized, { /// Returns the ordering between `self` and `other`. /// @@ -804,7 +841,7 @@ where } } -impl + ?Sized> Ord for Dname { +impl + ?Sized> Ord for Name { /// Returns the ordering between `self` and `other`. /// /// Domain name order is determined according to the ‘canonical DNS @@ -816,10 +853,10 @@ impl + ?Sized> Ord for Dname { } } -impl CanonicalOrd for Dname +impl CanonicalOrd for Name where Octs: AsRef<[u8]> + ?Sized, - N: ToDname + ?Sized, + N: ToName + ?Sized, { fn canonical_cmp(&self, other: &N) -> cmp::Ordering { self.name_cmp(other) @@ -828,7 +865,7 @@ where //--- Hash -impl + ?Sized> hash::Hash for Dname { +impl + ?Sized> hash::Hash for Name { fn hash(&self, state: &mut H) { for item in self.iter() { item.hash(state) @@ -836,13 +873,13 @@ impl + ?Sized> hash::Hash for Dname { } } -//--- ToLabelIter and ToDname +//--- ToLabelIter and ToName -impl ToLabelIter for Dname +impl ToLabelIter for Name where Octs: AsRef<[u8]> + ?Sized, { - type LabelIter<'a> = DnameIter<'a> where Octs: 'a; + type LabelIter<'a> = NameIter<'a> where Octs: 'a; fn iter_labels(&self) -> Self::LabelIter<'_> { self.iter() @@ -853,7 +890,7 @@ where } } -impl + ?Sized> ToDname for Dname { +impl + ?Sized> ToName for Name { fn as_flat_slice(&self) -> Option<&[u8]> { Some(self.0.as_ref()) } @@ -861,12 +898,12 @@ impl + ?Sized> ToDname for Dname { //--- IntoIterator -impl<'a, Octs> IntoIterator for &'a Dname +impl<'a, Octs> IntoIterator for &'a Name where Octs: AsRef<[u8]> + ?Sized, { type Item = &'a Label; - type IntoIter = DnameIter<'a>; + type IntoIter = NameIter<'a>; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -875,7 +912,7 @@ where //--- Display -impl + ?Sized> fmt::Display for Dname { +impl + ?Sized> fmt::Display for Name { /// Formats the domain name. /// /// This will produce the domain name in ‘common display format’ without @@ -899,41 +936,41 @@ impl + ?Sized> fmt::Display for Dname { //--- Debug -impl + ?Sized> fmt::Debug for Dname { +impl + ?Sized> fmt::Debug for Name { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "Dname({})", self.fmt_with_dot()) + write!(f, "Name({})", self.fmt_with_dot()) } } //--- AsRef and Borrow -impl> AsRef> for Dname { - fn as_ref(&self) -> &Dname<[u8]> { +impl> AsRef> for Name { + fn as_ref(&self) -> &Name<[u8]> { self.for_slice() } } /// Borrow a domain name. /// -/// Containers holding an owned `Dname<_>` may be queried with name over a +/// Containers holding an owned `Name<_>` may be queried with name over a /// slice. This `Borrow<_>` impl supports user code querying containers with /// compatible-but-different types like the following example: /// /// ``` /// use std::collections::HashMap; /// -/// use domain::base::Dname; +/// use domain::base::Name; /// /// fn get_description( -/// hash: &HashMap>, String> +/// hash: &HashMap>, String> /// ) -> Option<&str> { -/// let lookup_name: &Dname<[u8]> = -/// Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(); +/// let lookup_name: &Name<[u8]> = +/// Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(); /// hash.get(lookup_name).map(|x| x.as_ref()) /// } /// ``` -impl> borrow::Borrow> for Dname { - fn borrow(&self) -> &Dname<[u8]> { +impl> borrow::Borrow> for Name { + fn borrow(&self) -> &Name<[u8]> { self.for_slice() } } @@ -941,7 +978,7 @@ impl> borrow::Borrow> for Dname { //--- Serialize and Deserialize #[cfg(feature = "serde")] -impl serde::Serialize for Dname +impl serde::Serialize for Name where Octs: AsRef<[u8]> + SerializeOctets + ?Sized, { @@ -951,10 +988,10 @@ where ) -> Result { if serializer.is_human_readable() { serializer - .serialize_newtype_struct("Dname", &format_args!("{}", self)) + .serialize_newtype_struct("Name", &format_args!("{}", self)) } else { serializer.serialize_newtype_struct( - "Dname", + "Name", &self.0.as_serialized_octets(), ) } @@ -962,7 +999,7 @@ where } #[cfg(feature = "serde")] -impl<'de, Octs> serde::Deserialize<'de> for Dname +impl<'de, Octs> serde::Deserialize<'de> for Name where Octs: FromBuilder + DeserializeOctets<'de>, ::Builder: FreezeBuilder @@ -985,7 +1022,7 @@ where + AsRef<[u8]> + AsMut<[u8]>, { - type Value = Dname; + type Value = Name; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("an absolute domain name") @@ -995,7 +1032,7 @@ where self, v: &str, ) -> Result { - Dname::from_str(v).map_err(E::custom) + Name::from_str(v).map_err(E::custom) } fn visit_borrowed_bytes( @@ -1003,7 +1040,7 @@ where value: &'de [u8], ) -> Result { self.0.visit_borrowed_bytes(value).and_then(|octets| { - Dname::from_octets(octets).map_err(E::custom) + Name::from_octets(octets).map_err(E::custom) }) } @@ -1013,7 +1050,7 @@ where value: std::vec::Vec, ) -> Result { self.0.visit_byte_buf(value).and_then(|octets| { - Dname::from_octets(octets).map_err(E::custom) + Name::from_octets(octets).map_err(E::custom) }) } } @@ -1028,7 +1065,7 @@ where + AsRef<[u8]> + AsMut<[u8]>, { - type Value = Dname; + type Value = Name; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("an absolute domain name") @@ -1051,7 +1088,7 @@ where } deserializer - .deserialize_newtype_struct("Dname", NewtypeVisitor(PhantomData)) + .deserialize_newtype_struct("Name", NewtypeVisitor(PhantomData)) } } @@ -1060,13 +1097,13 @@ where /// An iterator over ever shorter suffixes of a domain name. #[derive(Clone)] pub struct SuffixIter<'a, Octs: ?Sized> { - name: &'a Dname, + name: &'a Name, start: Option, } impl<'a, Octs: ?Sized> SuffixIter<'a, Octs> { /// Creates a new iterator cloning `name`. - fn new(name: &'a Dname) -> Self { + fn new(name: &'a Name) -> Self { SuffixIter { name, start: Some(0), @@ -1075,7 +1112,7 @@ impl<'a, Octs: ?Sized> SuffixIter<'a, Octs> { } impl<'a, Octs: Octets + ?Sized> Iterator for SuffixIter<'a, Octs> { - type Item = Dname>; + type Item = Name>; fn next(&mut self) -> Option { let start = self.start?; @@ -1092,7 +1129,7 @@ impl<'a, Octs: Octets + ?Sized> Iterator for SuffixIter<'a, Octs> { //------------ DisplayWithDot ------------------------------------------------ -struct DisplayWithDot<'a>(&'a Dname<[u8]>); +struct DisplayWithDot<'a>(&'a Name<[u8]>); impl<'a> fmt::Display for DisplayWithDot<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -1111,11 +1148,11 @@ impl<'a> fmt::Display for DisplayWithDot<'a> { //============ Error Types =================================================== -//------------ DnameError ---------------------------------------------------- +//------------ NameError ----------------------------------------------------- /// A domain name wasn’t encoded correctly. #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct DnameError(DnameErrorEnum); +pub struct NameError(DnameErrorEnum); #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum DnameErrorEnum { @@ -1140,13 +1177,13 @@ enum DnameErrorEnum { //--- From -impl From for DnameError { +impl From for NameError { fn from(err: LabelTypeError) -> Self { Self(DnameErrorEnum::BadLabel(err)) } } -impl From for DnameError { +impl From for NameError { fn from(err: SplitLabelError) -> Self { Self(match err { SplitLabelError::Pointer(_) => DnameErrorEnum::CompressedName, @@ -1156,8 +1193,8 @@ impl From for DnameError { } } -impl From for FormError { - fn from(err: DnameError) -> FormError { +impl From for FormError { + fn from(err: NameError) -> FormError { FormError::new(match err.0 { DnameErrorEnum::BadLabel(_) => "unknown label type", DnameErrorEnum::CompressedName => "compressed domain name", @@ -1169,8 +1206,8 @@ impl From for FormError { } } -impl From for ParseError { - fn from(err: DnameError) -> ParseError { +impl From for ParseError { + fn from(err: NameError) -> ParseError { match err.0 { DnameErrorEnum::ShortInput => ParseError::ShortInput, _ => ParseError::Form(err.into()), @@ -1180,7 +1217,7 @@ impl From for ParseError { //--- Display and Error -impl fmt::Display for DnameError { +impl fmt::Display for NameError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { DnameErrorEnum::BadLabel(ref err) => err.fmt(f), @@ -1196,7 +1233,7 @@ impl fmt::Display for DnameError { } #[cfg(feature = "std")] -impl std::error::Error for DnameError {} +impl std::error::Error for NameError {} //============ Testing ======================================================= // @@ -1217,16 +1254,16 @@ pub(crate) mod test { #[test] fn impls() { - fn assert_to_dname(_: &T) {} + fn assert_to_name(_: &T) {} - assert_to_dname(Dname::from_slice(b"\0".as_ref()).unwrap()); - assert_to_dname(&Dname::from_octets(b"\0").unwrap()); - assert_to_dname(&Dname::from_octets(b"\0".as_ref()).unwrap()); + assert_to_name(Name::from_slice(b"\0".as_ref()).unwrap()); + assert_to_name(&Name::from_octets(b"\0").unwrap()); + assert_to_name(&Name::from_octets(b"\0".as_ref()).unwrap()); #[cfg(feature = "std")] { - assert_to_dname( - &Dname::from_octets(Vec::from(b"\0".as_ref())).unwrap(), + assert_to_name( + &Name::from_octets(Vec::from(b"\0".as_ref())).unwrap(), ); } } @@ -1234,27 +1271,27 @@ pub(crate) mod test { #[cfg(feature = "bytes")] #[test] fn impls_bytes() { - fn assert_to_dname(_: &T) {} + fn assert_to_name(_: &T) {} - assert_to_dname( - &Dname::from_octets(Bytes::from(b"\0".as_ref())).unwrap(), + assert_to_name( + &Name::from_octets(Bytes::from(b"\0".as_ref())).unwrap(), ); } #[test] fn root() { - assert_eq!(Dname::root_ref().as_slice(), b"\0"); + assert_eq!(Name::root_ref().as_slice(), b"\0"); #[cfg(feature = "std")] { - assert_eq!(Dname::root_vec().as_slice(), b"\0"); + assert_eq!(Name::root_vec().as_slice(), b"\0"); } - assert_eq!(Dname::root_slice().as_slice(), b"\0"); + assert_eq!(Name::root_slice().as_slice(), b"\0"); } #[cfg(feature = "bytes")] #[test] fn root_bytes() { - assert_eq!(Dname::root_bytes().as_slice(), b"\0"); + assert_eq!(Name::root_bytes().as_slice(), b"\0"); } #[test] @@ -1262,7 +1299,7 @@ pub(crate) mod test { fn from_slice() { // a simple good name assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .as_slice(), b"\x03www\x07example\x03com\0" @@ -1270,23 +1307,23 @@ pub(crate) mod test { // relative name assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com"), - Err(DnameError(DnameErrorEnum::RelativeName)) + Name::from_slice(b"\x03www\x07example\x03com"), + Err(NameError(DnameErrorEnum::RelativeName)) ); // bytes shorter than what label length says. assert_eq!( - Dname::from_slice(b"\x03www\x07exa"), - Err(DnameError(DnameErrorEnum::ShortInput)) + Name::from_slice(b"\x03www\x07exa"), + Err(NameError(DnameErrorEnum::ShortInput)) ); // label 63 long ok, 64 bad. let mut slice = [0u8; 65]; slice[0] = 63; - assert!(Dname::from_slice(&slice[..]).is_ok()); + assert!(Name::from_slice(&slice[..]).is_ok()); let mut slice = [0u8; 66]; slice[0] = 64; - assert!(Dname::from_slice(&slice[..]).is_err()); + assert!(Name::from_slice(&slice[..]).is_err()); // name 255 long ok, 256 bad. let mut buf = std::vec::Vec::new(); @@ -1296,42 +1333,64 @@ pub(crate) mod test { assert_eq!(buf.len(), 250); let mut tmp = buf.clone(); tmp.extend_from_slice(b"\x03123\0"); - assert_eq!(Dname::from_slice(&tmp).map(|_| ()), Ok(())); + assert_eq!(Name::from_slice(&tmp).map(|_| ()), Ok(())); buf.extend_from_slice(b"\x041234\0"); - assert!(Dname::from_slice(&buf).is_err()); + assert!(Name::from_slice(&buf).is_err()); // trailing data - assert!(Dname::from_slice(b"\x03com\0\x03www\0").is_err()); + assert!(Name::from_slice(b"\x03com\0\x03www\0").is_err()); // bad label heads: compressed, other types. assert_eq!( - Dname::from_slice(b"\xa2asdasds"), + Name::from_slice(b"\xa2asdasds"), Err(LabelTypeError::Undefined.into()) ); assert_eq!( - Dname::from_slice(b"\x62asdasds"), + Name::from_slice(b"\x62asdasds"), Err(LabelTypeError::Extended(0x62).into()) ); assert_eq!( - Dname::from_slice(b"\xccasdasds"), - Err(DnameError(DnameErrorEnum::CompressedName)) + Name::from_slice(b"\xccasdasds"), + Err(NameError(DnameErrorEnum::CompressedName)) ); // empty input assert_eq!( - Dname::from_slice(b""), - Err(DnameError(DnameErrorEnum::ShortInput)) + Name::from_slice(b""), + Err(NameError(DnameErrorEnum::ShortInput)) ); } - // `Dname::from_chars` is covered in the `FromStr` test. + #[test] + fn test_dname_from_addr() { + type TestName = Name>; + + assert_eq!( + TestName::reverse_from_addr([192, 0, 2, 12].into()).unwrap(), + TestName::from_str("12.2.0.192.in-addr.arpa").unwrap() + ); + assert_eq!( + TestName::reverse_from_addr( + [0x2001, 0xdb8, 0x1234, 0x0, 0x5678, 0x1, 0x9abc, 0xdef] + .into() + ) + .unwrap(), + TestName::from_str( + "f.e.d.0.c.b.a.9.1.0.0.0.8.7.6.5.\ + 0.0.0.0.4.3.2.1.8.b.d.0.1.0.0.2.\ + ip6.arpa" + ) + .unwrap() + ); + } + // `Name::from_chars` is covered in the `FromStr` test. // // No tests for the simple conversion methods because, well, simple. #[test] fn into_relative() { assert_eq!( - Dname::from_octets(b"\x03www\0".as_ref()) + Name::from_octets(b"\x03www\0".as_ref()) .unwrap() .into_relative() .as_slice(), @@ -1342,20 +1401,19 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn make_canonical() { - let mut name = - RelativeDname::vec_from_str("wWw.exAmpLE.coM").unwrap(); + let mut name = RelativeName::vec_from_str("wWw.exAmpLE.coM").unwrap(); name.make_canonical(); assert_eq!( name, - RelativeDname::from_octets(b"\x03www\x07example\x03com").unwrap() + RelativeName::from_octets(b"\x03www\x07example\x03com").unwrap() ); } #[test] fn is_root() { - assert!(Dname::from_slice(b"\0").unwrap().is_root()); - assert!(!Dname::from_slice(b"\x03www\0").unwrap().is_root()); - assert!(Dname::root_ref().is_root()); + assert!(Name::from_slice(b"\0").unwrap().is_root()); + assert!(!Name::from_slice(b"\x03www\0").unwrap().is_root()); + assert!(Name::root_ref().is_root()); } pub fn cmp_iter(mut iter: I, labels: &[&[u8]]) @@ -1378,9 +1436,9 @@ pub(crate) mod test { #[test] fn iter() { - cmp_iter(Dname::root_ref().iter(), &[b""]); + cmp_iter(Name::root_ref().iter(), &[b""]); cmp_iter( - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .iter(), &[b"www", b"example", b"com", b""], @@ -1407,9 +1465,9 @@ pub(crate) mod test { #[test] fn iter_back() { - cmp_iter_back(Dname::root_ref().iter(), &[b""]); + cmp_iter_back(Name::root_ref().iter(), &[b""]); cmp_iter_back( - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .iter(), &[b"", b"com", b"example", b"www"], @@ -1418,9 +1476,9 @@ pub(crate) mod test { #[test] fn iter_suffixes() { - cmp_iter(Dname::root_ref().iter_suffixes(), &[b"\0"]); + cmp_iter(Name::root_ref().iter_suffixes(), &[b"\0"]); cmp_iter( - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) .unwrap() .iter_suffixes(), &[ @@ -1434,9 +1492,9 @@ pub(crate) mod test { #[test] fn label_count() { - assert_eq!(Dname::root_ref().label_count(), 1); + assert_eq!(Name::root_ref().label_count(), 1); assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .label_count(), 4 @@ -1445,9 +1503,9 @@ pub(crate) mod test { #[test] fn first() { - assert_eq!(Dname::root_ref().first().as_slice(), b""); + assert_eq!(Name::root_ref().first().as_slice(), b""); assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .first() .as_slice(), @@ -1457,9 +1515,9 @@ pub(crate) mod test { #[test] fn last() { - assert_eq!(Dname::root_ref().last().as_slice(), b""); + assert_eq!(Name::root_ref().last().as_slice(), b""); assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .last() .as_slice(), @@ -1469,45 +1527,44 @@ pub(crate) mod test { #[test] fn starts_with() { - let root = Dname::root_ref(); - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let root = Name::root_ref(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); assert!(root.starts_with(&root)); assert!(wecr.starts_with(&wecr)); - assert!(root.starts_with(&RelativeDname::empty_ref())); - assert!(wecr.starts_with(&RelativeDname::empty_ref())); + assert!(root.starts_with(&RelativeName::empty_ref())); + assert!(wecr.starts_with(&RelativeName::empty_ref())); - let test = RelativeDname::from_slice(b"\x03www").unwrap(); + let test = RelativeName::from_slice(b"\x03www").unwrap(); assert!(!root.starts_with(&test)); assert!(wecr.starts_with(&test)); - let test = RelativeDname::from_slice(b"\x03www\x07example").unwrap(); + let test = RelativeName::from_slice(b"\x03www\x07example").unwrap(); assert!(!root.starts_with(&test)); assert!(wecr.starts_with(&test)); let test = - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(); + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(); assert!(!root.starts_with(&test)); assert!(wecr.starts_with(&test)); - let test = RelativeDname::from_slice(b"\x07example\x03com").unwrap(); + let test = RelativeName::from_slice(b"\x07example\x03com").unwrap(); assert!(!root.starts_with(&test)); assert!(!wecr.starts_with(&test)); - let test = RelativeDname::from_octets(b"\x03www".as_ref()) + let test = RelativeName::from_octets(b"\x03www".as_ref()) .unwrap() .chain( - RelativeDname::from_octets(b"\x07example".as_ref()).unwrap(), + RelativeName::from_octets(b"\x07example".as_ref()).unwrap(), ) .unwrap(); assert!(!root.starts_with(&test)); assert!(wecr.starts_with(&test)); let test = test - .chain(RelativeDname::from_octets(b"\x03com".as_ref()).unwrap()) + .chain(RelativeName::from_octets(b"\x03com".as_ref()).unwrap()) .unwrap(); assert!(!root.starts_with(&test)); assert!(wecr.starts_with(&test)); @@ -1515,10 +1572,9 @@ pub(crate) mod test { #[test] fn ends_with() { - let root = Dname::root_ref(); - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let root = Name::root_ref(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); for name in wecr.iter_suffixes() { if name.is_root() { @@ -1532,7 +1588,7 @@ pub(crate) mod test { #[test] fn is_label_start() { - let wecr = Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(); + let wecr = Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(); assert!(wecr.is_label_start(0)); // \x03 assert!(!wecr.is_label_start(1)); // w @@ -1558,7 +1614,7 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn slice() { - let wecr = Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(); + let wecr = Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(); assert_eq!(wecr.slice(..4).as_slice(), b"\x03www"); assert_eq!(wecr.slice(..12).as_slice(), b"\x03www\x07example"); @@ -1577,7 +1633,7 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn slice_from() { - let wecr = Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(); + let wecr = Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(); assert_eq!( wecr.slice_from(0).as_slice(), @@ -1594,9 +1650,8 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn range() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); assert_eq!(wecr.range(0..4).as_slice(), b"\x03www"); assert_eq!(wecr.range(0..12).as_slice(), b"\x03www\x07example"); @@ -1615,9 +1670,8 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn range_from() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); assert_eq!( wecr.range_from(0).as_slice(), @@ -1634,9 +1688,8 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn split() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); let (left, right) = wecr.split(0); assert_eq!(left.as_slice(), b""); @@ -1663,9 +1716,8 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn truncate() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); assert_eq!(wecr.clone().truncate(0).as_slice(), b""); assert_eq!(wecr.clone().truncate(4).as_slice(), b"\x03www"); @@ -1686,9 +1738,8 @@ pub(crate) mod test { #[test] fn split_first() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); let (label, wecr) = wecr.split_first().unwrap(); assert_eq!(label, b"www".as_ref()); @@ -1706,9 +1757,8 @@ pub(crate) mod test { #[test] fn parent() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); let wecr = wecr.parent().unwrap(); assert_eq!(wecr.as_slice(), b"\x07example\x03com\0"); @@ -1721,18 +1771,16 @@ pub(crate) mod test { #[test] fn strip_suffix() { - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); let ecr = - Dname::from_octets(b"\x07example\x03com\0".as_ref()).unwrap(); - let cr = Dname::from_octets(b"\x03com\0".as_ref()).unwrap(); - let wenr = - Dname::from_octets(b"\x03www\x07example\x03net\0".as_ref()) - .unwrap(); + Name::from_octets(b"\x07example\x03com\0".as_ref()).unwrap(); + let cr = Name::from_octets(b"\x03com\0".as_ref()).unwrap(); + let wenr = Name::from_octets(b"\x03www\x07example\x03net\0".as_ref()) + .unwrap(); let enr = - Dname::from_octets(b"\x07example\x03net\0".as_ref()).unwrap(); - let nr = Dname::from_octets(b"\x03net\0".as_ref()).unwrap(); + Name::from_octets(b"\x07example\x03net\0".as_ref()).unwrap(); + let nr = Name::from_octets(b"\x03net\0".as_ref()).unwrap(); assert_eq!(wecr.clone().strip_suffix(&wecr).unwrap().as_slice(), b""); assert_eq!( @@ -1745,7 +1793,7 @@ pub(crate) mod test { ); assert_eq!( wecr.clone() - .strip_suffix(&Dname::root_slice()) + .strip_suffix(&Name::root_slice()) .unwrap() .as_slice(), b"\x03www\x07example\x03com" @@ -1771,30 +1819,30 @@ pub(crate) mod test { // Parse a correctly formatted name. let mut p = Parser::from_static(b"\x03www\x07example\x03com\0af"); assert_eq!( - Dname::parse(&mut p).unwrap().as_slice(), + Name::parse(&mut p).unwrap().as_slice(), b"\x03www\x07example\x03com\0" ); assert_eq!(p.peek_all(), b"af"); // Short buffer in middle of label. let mut p = Parser::from_static(b"\x03www\x07exam"); - assert_eq!(Dname::parse(&mut p), Err(ParseError::ShortInput)); + assert_eq!(Name::parse(&mut p), Err(ParseError::ShortInput)); // Short buffer at end of label. let mut p = Parser::from_static(b"\x03www\x07example"); - assert_eq!(Dname::parse(&mut p), Err(ParseError::ShortInput)); + assert_eq!(Name::parse(&mut p), Err(ParseError::ShortInput)); // Compressed name. let mut p = Parser::from_static(b"\x03com\x03www\x07example\xc0\0"); p.advance(4).unwrap(); assert_eq!( - Dname::parse(&mut p), - Err(DnameError(DnameErrorEnum::CompressedName).into()) + Name::parse(&mut p), + Err(NameError(DnameErrorEnum::CompressedName).into()) ); // Bad label header. let mut p = Parser::from_static(b"\x03www\x07example\xbffoo"); - assert!(Dname::parse(&mut p).is_err()); + assert!(Name::parse(&mut p).is_err()); // Long name: 255 bytes is fine. let mut buf = Vec::new(); @@ -1804,7 +1852,7 @@ pub(crate) mod test { buf.extend_from_slice(b"\x03123\0"); assert_eq!(buf.len(), 255); let mut p = Parser::from_ref(buf.as_slice()); - assert!(Dname::parse(&mut p).is_ok()); + assert!(Name::parse(&mut p).is_ok()); assert_eq!(p.peek_all(), b""); // Long name: 256 bytes are bad. @@ -1816,8 +1864,8 @@ pub(crate) mod test { assert_eq!(buf.len(), 256); let mut p = Parser::from_ref(buf.as_slice()); assert_eq!( - Dname::parse(&mut p), - Err(DnameError(DnameErrorEnum::LongName).into()) + Name::parse(&mut p), + Err(NameError(DnameErrorEnum::LongName).into()) ); } @@ -1831,7 +1879,7 @@ pub(crate) mod test { let mut buf = Vec::new(); infallible( - Dname::from_slice(b"\x03wWw\x07exaMPle\x03com\0") + Name::from_slice(b"\x03wWw\x07exaMPle\x03com\0") .unwrap() .compose_canonical(&mut buf), ); @@ -1841,25 +1889,22 @@ pub(crate) mod test { #[test] #[cfg(feature = "std")] fn from_str() { - // Another simple test. `DnameBuilder` does all the heavy lifting, + // Another simple test. `NameBuilder` does all the heavy lifting, // so we don’t need to test all the escape sequence shenanigans here. // Just check that we’ll always get a name, final dot or not, unless // the string is empty. use core::str::FromStr; use std::vec::Vec; + assert_eq!(Name::>::from_str(".").unwrap().as_slice(), b"\0"); assert_eq!( - Dname::>::from_str(".").unwrap().as_slice(), - b"\0" - ); - assert_eq!( - Dname::>::from_str("www.example.com") + Name::>::from_str("www.example.com") .unwrap() .as_slice(), b"\x03www\x07example\x03com\0" ); assert_eq!( - Dname::>::from_str("www.example.com.") + Name::>::from_str("www.example.com.") .unwrap() .as_slice(), b"\x03www\x07example\x03com\0" @@ -1869,57 +1914,51 @@ pub(crate) mod test { #[test] fn eq() { assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(), - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap() + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(), + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap() ); assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(), - Dname::from_slice(b"\x03wWw\x07eXAMple\x03Com\0").unwrap() + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(), + Name::from_slice(b"\x03wWw\x07eXAMple\x03Com\0").unwrap() ); assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(), - &RelativeDname::from_octets(b"\x03www".as_ref()) + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(), + &RelativeName::from_octets(b"\x03www".as_ref()) .unwrap() .chain( - RelativeDname::from_octets( - b"\x07example\x03com".as_ref() - ) - .unwrap() + RelativeName::from_octets(b"\x07example\x03com".as_ref()) + .unwrap() ) .unwrap() - .chain(Dname::root_ref()) + .chain(Name::root_ref()) .unwrap() ); assert_eq!( - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(), - &RelativeDname::from_octets(b"\x03wWw".as_ref()) + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(), + &RelativeName::from_octets(b"\x03wWw".as_ref()) .unwrap() .chain( - RelativeDname::from_octets( - b"\x07eXAMple\x03coM".as_ref() - ) - .unwrap() + RelativeName::from_octets(b"\x07eXAMple\x03coM".as_ref()) + .unwrap() ) .unwrap() - .chain(Dname::root_ref()) + .chain(Name::root_ref()) .unwrap() ); assert_ne!( - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(), - Dname::from_slice(b"\x03ww4\x07example\x03com\0").unwrap() + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(), + Name::from_slice(b"\x03ww4\x07example\x03com\0").unwrap() ); assert_ne!( - Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(), - &RelativeDname::from_octets(b"\x03www".as_ref()) + Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(), + &RelativeName::from_octets(b"\x03www".as_ref()) .unwrap() .chain( - RelativeDname::from_octets( - b"\x073xample\x03com".as_ref() - ) - .unwrap() + RelativeName::from_octets(b"\x073xample\x03com".as_ref()) + .unwrap() ) .unwrap() - .chain(Dname::root_ref()) + .chain(Name::root_ref()) .unwrap() ); } @@ -1930,15 +1969,15 @@ pub(crate) mod test { // The following is taken from section 6.1 of RFC 4034. let names = [ - Dname::from_slice(b"\x07example\0").unwrap(), - Dname::from_slice(b"\x01a\x07example\0").unwrap(), - Dname::from_slice(b"\x08yljkjljk\x01a\x07example\0").unwrap(), - Dname::from_slice(b"\x01Z\x01a\x07example\0").unwrap(), - Dname::from_slice(b"\x04zABC\x01a\x07example\0").unwrap(), - Dname::from_slice(b"\x01z\x07example\0").unwrap(), - Dname::from_slice(b"\x01\x01\x01z\x07example\0").unwrap(), - Dname::from_slice(b"\x01*\x01z\x07example\0").unwrap(), - Dname::from_slice(b"\x01\xc8\x01z\x07example\0").unwrap(), + Name::from_slice(b"\x07example\0").unwrap(), + Name::from_slice(b"\x01a\x07example\0").unwrap(), + Name::from_slice(b"\x08yljkjljk\x01a\x07example\0").unwrap(), + Name::from_slice(b"\x01Z\x01a\x07example\0").unwrap(), + Name::from_slice(b"\x04zABC\x01a\x07example\0").unwrap(), + Name::from_slice(b"\x01z\x07example\0").unwrap(), + Name::from_slice(b"\x01\x01\x01z\x07example\0").unwrap(), + Name::from_slice(b"\x01*\x01z\x07example\0").unwrap(), + Name::from_slice(b"\x01\xc8\x01z\x07example\0").unwrap(), ]; for i in 0..names.len() { for j in 0..names.len() { @@ -1948,8 +1987,8 @@ pub(crate) mod test { } } - let n1 = Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(); - let n2 = Dname::from_slice(b"\x03wWw\x07eXAMple\x03Com\0").unwrap(); + let n1 = Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(); + let n2 = Name::from_slice(b"\x03wWw\x07eXAMple\x03Com\0").unwrap(); assert_eq!(n1.partial_cmp(n2), Some(Ordering::Equal)); assert_eq!(n1.cmp(n2), Ordering::Equal); } @@ -1962,10 +2001,10 @@ pub(crate) mod test { let mut s1 = DefaultHasher::new(); let mut s2 = DefaultHasher::new(); - Dname::from_slice(b"\x03www\x07example\x03com\0") + Name::from_slice(b"\x03www\x07example\x03com\0") .unwrap() .hash(&mut s1); - Dname::from_slice(b"\x03wWw\x07eXAMple\x03Com\0") + Name::from_slice(b"\x03wWw\x07eXAMple\x03Com\0") .unwrap() .hash(&mut s2); assert_eq!(s1.finish(), s2.finish()); @@ -1979,7 +2018,7 @@ pub(crate) mod test { use std::string::ToString; fn cmp(bytes: &[u8], fmt: &str, fmt_with_dot: &str) { - let name = Dname::from_octets(bytes).unwrap(); + let name = Name::from_octets(bytes).unwrap(); assert_eq!(name.to_string(), fmt); assert_eq!(format!("{}", name.fmt_with_dot()), fmt_with_dot); } @@ -1994,24 +2033,24 @@ pub(crate) mod test { fn ser_de() { use serde_test::{assert_tokens, Configure, Token}; - let name = Dname::>::from_str("www.example.com.").unwrap(); + let name = Name::>::from_str("www.example.com.").unwrap(); assert_tokens( &name.clone().compact(), &[ - Token::NewtypeStruct { name: "Dname" }, + Token::NewtypeStruct { name: "Name" }, Token::ByteBuf(b"\x03www\x07example\x03com\0"), ], ); assert_tokens( &name.readable(), &[ - Token::NewtypeStruct { name: "Dname" }, + Token::NewtypeStruct { name: "Name" }, Token::Str("www.example.com"), ], ); assert_tokens( - &Dname::root_vec().readable(), - &[Token::NewtypeStruct { name: "Dname" }, Token::Str(".")], + &Name::root_vec().readable(), + &[Token::NewtypeStruct { name: "Name" }, Token::Str(".")], ); } } diff --git a/src/base/name/builder.rs b/src/base/name/builder.rs index c6896a497..41cde5110 100644 --- a/src/base/name/builder.rs +++ b/src/base/name/builder.rs @@ -1,12 +1,12 @@ //! Building a domain name. //! -//! This is a private module for tidiness. `DnameBuilder` and `PushError` +//! This is a private module for tidiness. `NameBuilder` and `PushError` //! are re-exported by the parent module. use super::super::scan::{BadSymbol, Symbol, SymbolCharsError, Symbols}; -use super::dname::Dname; -use super::relative::{RelativeDname, RelativeDnameError}; -use super::traits::{ToDname, ToRelativeDname}; +use super::absolute::Name; +use super::relative::{RelativeName, RelativeNameError}; +use super::traits::{ToName, ToRelativeName}; use super::Label; #[cfg(feature = "bytes")] use bytes::BytesMut; @@ -15,7 +15,7 @@ use octseq::builder::{EmptyBuilder, FreezeBuilder, OctetsBuilder, ShortBuf}; #[cfg(feature = "std")] use std::vec::Vec; -//------------ DnameBuilder -------------------------------------------------- +//------------ NameBuilder -------------------------------------------------- /// Builds a domain name step by step by appending data. /// @@ -34,7 +34,7 @@ use std::vec::Vec; /// The name builder currently is not aware of internationalized domain /// names. The octets passed to it are used as is and are not converted. #[derive(Clone)] -pub struct DnameBuilder { +pub struct NameBuilder { /// The buffer to build the name in. builder: Builder, @@ -44,14 +44,14 @@ pub struct DnameBuilder { head: Option, } -impl DnameBuilder { +impl NameBuilder { /// Creates a new domain name builder from an octets builder. /// /// Whatever is in the buffer already is considered to be a relative /// domain name. Since that may not be the case, this function is /// unsafe. pub(super) unsafe fn from_builder_unchecked(builder: Builder) -> Self { - DnameBuilder { + NameBuilder { builder, head: None, } @@ -63,7 +63,7 @@ impl DnameBuilder { where Builder: EmptyBuilder, { - unsafe { DnameBuilder::from_builder_unchecked(Builder::empty()) } + unsafe { NameBuilder::from_builder_unchecked(Builder::empty()) } } /// Creates a new, empty builder with a given capacity. @@ -73,7 +73,7 @@ impl DnameBuilder { Builder: EmptyBuilder, { unsafe { - DnameBuilder::from_builder_unchecked(Builder::with_capacity( + NameBuilder::from_builder_unchecked(Builder::with_capacity( capacity, )) } @@ -83,17 +83,17 @@ impl DnameBuilder { /// /// The function checks that whatever is in the builder already /// consititutes a correctly encoded relative domain name. - pub fn from_builder(builder: Builder) -> Result + pub fn from_builder(builder: Builder) -> Result where Builder: OctetsBuilder + AsRef<[u8]>, { - RelativeDname::check_slice(builder.as_ref())?; - Ok(unsafe { DnameBuilder::from_builder_unchecked(builder) }) + RelativeName::check_slice(builder.as_ref())?; + Ok(unsafe { NameBuilder::from_builder_unchecked(builder) }) } } #[cfg(feature = "std")] -impl DnameBuilder> { +impl NameBuilder> { /// Creates an empty domain name builder atop a `Vec`. #[must_use] pub fn new_vec() -> Self { @@ -111,7 +111,7 @@ impl DnameBuilder> { } #[cfg(feature = "bytes")] -impl DnameBuilder { +impl NameBuilder { /// Creates an empty domain name bulider atop a bytes value. pub fn new_bytes() -> Self { Self::new() @@ -126,7 +126,7 @@ impl DnameBuilder { } } -impl> DnameBuilder { +impl> NameBuilder { /// Returns the already assembled domain name as an octets slice. pub fn as_slice(&self) -> &[u8] { self.builder.as_ref() @@ -143,7 +143,7 @@ impl> DnameBuilder { } } -impl DnameBuilder +impl NameBuilder where Builder: OctetsBuilder + AsRef<[u8]> + AsMut<[u8]>, { @@ -266,6 +266,70 @@ where Ok(()) } + /// Appends a label with the decimal representation of `u8`. + /// + /// If there currently is a label under construction, it will be ended + /// before appending `label`. + /// + /// Returns an error if appending would result in a name longer than 254 + /// bytes. + pub fn append_dec_u8_label( + &mut self, + value: u8, + ) -> Result<(), PushError> { + self.end_label(); + let hecto = value / 100; + if hecto > 0 { + self.push(hecto + b'0')?; + } + let deka = (value / 10) % 10; + if hecto > 0 || deka > 0 { + self.push(deka + b'0')?; + } + self.push(value % 10 + b'0')?; + self.end_label(); + Ok(()) + } + + /// Appends a label with the hex digit. + /// + /// If there currently is a label under construction, it will be ended + /// before appending `label`. + /// + /// Returns an error if appending would result in a name longer than 254 + /// bytes. + pub fn append_hex_digit_label( + &mut self, + nibble: u8, + ) -> Result<(), PushError> { + fn hex_digit(nibble: u8) -> u8 { + match nibble & 0x0F { + 0 => b'0', + 1 => b'1', + 2 => b'2', + 3 => b'3', + 4 => b'4', + 5 => b'5', + 6 => b'6', + 7 => b'7', + 8 => b'8', + 9 => b'9', + 10 => b'A', + 11 => b'B', + 12 => b'C', + 13 => b'D', + 14 => b'E', + 15 => b'F', + _ => unreachable!(), + } + } + + self.end_label(); + self.push(hex_digit(nibble))?; + self.end_label(); + Ok(()) + } + /// Appends a relative domain name. /// /// If there currently is a label under construction, it will be ended @@ -275,7 +339,7 @@ where /// bytes. // // XXX NEEDS TESTS - pub fn append_name( + pub fn append_name( &mut self, name: &N, ) -> Result<(), PushNameError> { @@ -346,48 +410,48 @@ where /// explicitely. /// /// This method converts the builder into a relative name. If you would - /// like to turn it into an absolute name, use [`into_dname`] which + /// like to turn it into an absolute name, use [`into_name`] which /// appends the root label before finishing. /// /// [`end_label`]: #method.end_label - /// [`into_dname`]: #method.into_dname - pub fn finish(mut self) -> RelativeDname + /// [`into_name`]: #method.into_name + pub fn finish(mut self) -> RelativeName where Builder: FreezeBuilder, { self.end_label(); - unsafe { RelativeDname::from_octets_unchecked(self.builder.freeze()) } + unsafe { RelativeName::from_octets_unchecked(self.builder.freeze()) } } - /// Appends the root label to the name and returns it as a `Dname`. + /// Appends the root label to the name and returns it as a `Name`. /// /// If there currently is a label under construction, ends the label. /// Then adds the empty root label and transforms the name into a - /// `Dname`. - pub fn into_dname(mut self) -> Result, PushError> + /// `Name`. + pub fn into_name(mut self) -> Result, PushError> where Builder: FreezeBuilder, { self.end_label(); self._append_slice(&[0])?; - Ok(unsafe { Dname::from_octets_unchecked(self.builder.freeze()) }) + Ok(unsafe { Name::from_octets_unchecked(self.builder.freeze()) }) } - /// Appends an origin and returns the resulting `Dname`. + /// Appends an origin and returns the resulting `Name`. /// If there currently is a label under construction, ends the label. /// Then adds the `origin` and transforms the name into a - /// `Dname`. + /// `Name`. // // XXX NEEDS TESTS - pub fn append_origin( + pub fn append_origin( mut self, origin: &N, - ) -> Result, PushNameError> + ) -> Result, PushNameError> where Builder: FreezeBuilder, { self.end_label(); - if self.len() + usize::from(origin.compose_len()) > Dname::MAX_LEN { + if self.len() + usize::from(origin.compose_len()) > Name::MAX_LEN { return Err(PushNameError::LongName); } for label in origin.iter_labels() { @@ -395,13 +459,13 @@ where .compose(&mut self.builder) .map_err(|_| PushNameError::ShortBuf)?; } - Ok(unsafe { Dname::from_octets_unchecked(self.builder.freeze()) }) + Ok(unsafe { Name::from_octets_unchecked(self.builder.freeze()) }) } } //--- Default -impl Default for DnameBuilder { +impl Default for NameBuilder { fn default() -> Self { Self::new() } @@ -409,7 +473,7 @@ impl Default for DnameBuilder { //--- AsRef -impl> AsRef<[u8]> for DnameBuilder { +impl> AsRef<[u8]> for NameBuilder { fn as_ref(&self) -> &[u8] { self.builder.as_ref() } @@ -708,7 +772,7 @@ mod test { #[test] fn compose() { - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); builder.push(b'w').unwrap(); builder.append_slice(b"ww").unwrap(); builder.end_label(); @@ -723,7 +787,7 @@ mod test { #[test] fn build_by_label() { - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); builder.append_label(b"www").unwrap(); builder.append_label(b"example").unwrap(); builder.append_label(b"com").unwrap(); @@ -732,7 +796,7 @@ mod test { #[test] fn build_mixed() { - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); builder.push(b'w').unwrap(); builder.append_slice(b"ww").unwrap(); builder.append_label(b"example").unwrap(); @@ -742,7 +806,7 @@ mod test { #[test] fn name_limit() { - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); for _ in 0..25 { // 9 bytes label is 10 bytes in total builder.append_label(b"123456789").unwrap(); @@ -761,7 +825,7 @@ mod test { #[test] fn label_limit() { - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); builder.append_label(&[0u8; 63][..]).unwrap(); assert_eq!( builder.append_label(&[0u8; 64][..]), @@ -782,7 +846,7 @@ mod test { #[test] fn finish() { - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); builder.append_label(b"www").unwrap(); builder.append_label(b"example").unwrap(); builder.append_slice(b"com").unwrap(); @@ -790,13 +854,13 @@ mod test { } #[test] - fn into_dname() { - let mut builder = DnameBuilder::new_vec(); + fn into_name() { + let mut builder = NameBuilder::new_vec(); builder.append_label(b"www").unwrap(); builder.append_label(b"example").unwrap(); builder.append_slice(b"com").unwrap(); assert_eq!( - builder.into_dname().unwrap().as_slice(), + builder.into_name().unwrap().as_slice(), b"\x03www\x07example\x03com\x00" ); } diff --git a/src/base/name/chain.rs b/src/base/name/chain.rs index efa6dccea..1d56198c5 100644 --- a/src/base/name/chain.rs +++ b/src/base/name/chain.rs @@ -5,10 +5,10 @@ use super::super::scan::Scanner; use super::label::Label; -use super::relative::DnameIter; -use super::traits::{FlattenInto, ToDname, ToLabelIter, ToRelativeDname}; -use super::uncertain::UncertainDname; -use super::Dname; +use super::relative::NameIter; +use super::traits::{FlattenInto, ToLabelIter, ToName, ToRelativeName}; +use super::uncertain::UncertainName; +use super::Name; use core::{fmt, iter}; use octseq::builder::{ BuilderAppendError, EmptyBuilder, FreezeBuilder, FromBuilder, @@ -19,20 +19,20 @@ use octseq::builder::{ /// Two domain names chained together. /// /// This type is the result of calling the `chain` method on -/// [`RelativeDname`], [`UncertainDname`], or on [`Chain`] itself. +/// [`RelativeName`], [`UncertainName`], or on [`Chain`] itself. /// /// The chain can be both an absolute or relative domain name—and implements -/// the respective traits [`ToDname`] or [`ToRelativeDname`]—, depending on +/// the respective traits [`ToName`] or [`ToRelativeName`]—, depending on /// whether the second name is absolute or relative. /// /// A chain on an uncertain name is special in that the second name is only /// used if the uncertain name is relative. /// -/// [`RelativeDname`]: struct.RelativeDname.html#method.chain +/// [`RelativeName`]: struct.RelativeName.html#method.chain /// [`Chain`]: #method.chain -/// [`ToDname`]: trait.ToDname.html -/// [`ToRelativeDname`]: trait.ToRelativeDname.html -/// [`UncertainDname`]: struct.UncertainDname.html#method.chain +/// [`ToName`]: trait.ToName.html +/// [`ToRelativeName`]: trait.ToRelativeName.html +/// [`UncertainName`]: struct.UncertainName.html#method.chain #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct Chain { @@ -47,9 +47,9 @@ impl Chain { /// Creates a new chain from a first and second name. pub(super) fn new(left: L, right: R) -> Result { if usize::from(left.compose_len() + right.compose_len()) - > Dname::MAX_LEN + > Name::MAX_LEN { - // TODO can't infer a specific type for Dname here + // TODO can't infer a specific type for Name here Err(LongChainError(())) } else { Ok(Chain { left, right }) @@ -57,18 +57,18 @@ impl Chain { } } -impl, R: ToLabelIter> Chain, R> { +impl, R: ToLabelIter> Chain, R> { /// Creates a chain from an uncertain name. /// /// This function is separate because the ultimate size depends on the /// variant of the left name. pub(super) fn new_uncertain( - left: UncertainDname, + left: UncertainName, right: R, ) -> Result { - if let UncertainDname::Relative(ref name) = left { + if let UncertainName::Relative(ref name) = left { if usize::from(name.compose_len() + right.compose_len()) - > Dname::MAX_LEN + > Name::MAX_LEN { return Err(LongChainError(())); } @@ -78,27 +78,27 @@ impl, R: ToLabelIter> Chain, R> { } impl Chain { - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { - scanner.scan_dname() + scanner.scan_name() } } -impl Chain { +impl Chain { /// Extends the chain with another domain name. /// /// While the method accepts anything [`Compose`] as the second element of - /// the chain, the resulting `Chain` will only implement [`ToDname`] or - /// [`ToRelativeDname`] if if also implements [`ToDname`] or - /// [`ToRelativeDname`], respectively. + /// the chain, the resulting `Chain` will only implement [`ToName`] or + /// [`ToRelativeName`] if if also implements [`ToName`] or + /// [`ToRelativeName`], respectively. /// /// The method will fail with an error if the chained name is longer than /// 255 bytes. /// /// [`Compose`]: ../compose/trait.Compose.html - /// [`ToDname`]: trait.ToDname.html - /// [`ToRelativeDname`]: trait.ToRelativeDname.html + /// [`ToName`]: trait.ToName.html + /// [`ToRelativeName`]: trait.ToRelativeName.html pub fn chain( self, other: N, @@ -129,9 +129,9 @@ impl Chain { } } -//--- ToLabelIter, ToRelativeDname, ToDname +//--- ToLabelIter, ToRelativeName, ToName -impl ToLabelIter for Chain { +impl ToLabelIter for Chain { type LabelIter<'a> = ChainIter<'a, L, R> where L: 'a, R: 'a; fn iter_labels(&self) -> Self::LabelIter<'_> { @@ -146,20 +146,20 @@ impl ToLabelIter for Chain { } } -impl ToLabelIter for Chain, R> +impl ToLabelIter for Chain, R> where Octs: AsRef<[u8]>, - R: ToDname, + R: ToName, { type LabelIter<'a> = UncertainChainIter<'a, Octs, R> where Octs: 'a, R: 'a; fn iter_labels(&self) -> Self::LabelIter<'_> { match self.left { - UncertainDname::Absolute(ref name) => { + UncertainName::Absolute(ref name) => { UncertainChainIter::Absolute(name.iter_labels()) } - UncertainDname::Relative(ref name) => { + UncertainName::Relative(ref name) => { UncertainChainIter::Relative(ChainIter( name.iter_labels().chain(self.right.iter_labels()), )) @@ -169,8 +169,8 @@ where fn compose_len(&self) -> u16 { match self.left { - UncertainDname::Absolute(ref name) => name.compose_len(), - UncertainDname::Relative(ref name) => name + UncertainName::Absolute(ref name) => name.compose_len(), + UncertainName::Relative(ref name) => name .compose_len() .checked_add(self.right.compose_len()) .expect("long domain name"), @@ -178,37 +178,37 @@ where } } -impl ToRelativeDname for Chain {} +impl ToRelativeName for Chain {} -impl ToDname for Chain {} +impl ToName for Chain {} -impl ToDname for Chain, R> +impl ToName for Chain, R> where Octets: AsRef<[u8]>, - R: ToDname, + R: ToName, { } //--- FlattenInto -impl FlattenInto> for Chain +impl FlattenInto> for Chain where - L: ToRelativeDname, - R: ToDname, - R: FlattenInto, AppendError = BuilderAppendError>, + L: ToRelativeName, + R: ToName, + R: FlattenInto, AppendError = BuilderAppendError>, Target: FromBuilder, ::Builder: EmptyBuilder, { type AppendError = BuilderAppendError; - fn try_flatten_into(self) -> Result, Self::AppendError> { + fn try_flatten_into(self) -> Result, Self::AppendError> { if self.left.is_empty() { self.right.try_flatten_into() } else { let mut builder = Target::Builder::with_capacity(self.compose_len().into()); self.compose(&mut builder)?; - Ok(unsafe { Dname::from_octets_unchecked(builder.freeze()) }) + Ok(unsafe { Name::from_octets_unchecked(builder.freeze()) }) } } } @@ -283,8 +283,8 @@ where /// The label iterator for domain name chains with uncertain domain names. pub enum UncertainChainIter<'a, Octets: AsRef<[u8]>, R: ToLabelIter> { - Absolute(DnameIter<'a>), - Relative(ChainIter<'a, UncertainDname, R>), + Absolute(NameIter<'a>), + Relative(ChainIter<'a, UncertainName, R>), } impl<'a, Octets, R> Clone for UncertainChainIter<'a, Octets, R> @@ -381,51 +381,49 @@ impl std::error::Error for LongChainError {} #[cfg(feature = "std")] mod test { use super::*; - use crate::base::name::RelativeDname; + use crate::base::name::RelativeName; use octseq::builder::infallible; - /// Tests that `ToDname` and `ToRelativeDname` are implemented for the + /// Tests that `ToName` and `ToRelativeName` are implemented for the /// right types. #[test] #[cfg(feature = "std")] fn impls() { - fn assert_to_dname(_: &T) {} - fn assert_to_relative_dname(_: &T) {} + fn assert_to_name(_: &T) {} + fn assert_to_relative_name(_: &T) {} - let rel = RelativeDname::empty_ref() - .chain(RelativeDname::empty_ref()) + let rel = RelativeName::empty_ref() + .chain(RelativeName::empty_ref()) .unwrap(); - assert_to_relative_dname(&rel); - assert_to_dname( - &RelativeDname::empty_ref().chain(Dname::root_ref()).unwrap(), + assert_to_relative_name(&rel); + assert_to_name( + &RelativeName::empty_ref().chain(Name::root_ref()).unwrap(), ); - assert_to_dname( - &RelativeDname::empty_ref() - .chain(RelativeDname::empty_ref()) + assert_to_name( + &RelativeName::empty_ref() + .chain(RelativeName::empty_ref()) .unwrap() - .chain(Dname::root_ref()) + .chain(Name::root_ref()) .unwrap(), ); - assert_to_dname(&rel.clone().chain(Dname::root_ref()).unwrap()); - assert_to_relative_dname( - &rel.chain(RelativeDname::empty_ref()).unwrap(), + assert_to_name(&rel.clone().chain(Name::root_ref()).unwrap()); + assert_to_relative_name( + &rel.chain(RelativeName::empty_ref()).unwrap(), ); - assert_to_dname( - &UncertainDname::root_vec().chain(Dname::root_vec()).unwrap(), + assert_to_name( + &UncertainName::root_vec().chain(Name::root_vec()).unwrap(), ); - assert_to_dname( - &UncertainDname::empty_vec() - .chain(Dname::root_vec()) - .unwrap(), + assert_to_name( + &UncertainName::empty_vec().chain(Name::root_vec()).unwrap(), ); } /// Tests that a chain never becomes too long. #[test] fn name_limit() { - use crate::base::name::DnameBuilder; + use crate::base::name::NameBuilder; - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); for _ in 0..25 { // 9 bytes label is 10 bytes in total builder.append_label(b"123456789").unwrap(); @@ -433,14 +431,14 @@ mod test { let left = builder.finish(); assert_eq!(left.len(), 250); - let mut builder = DnameBuilder::new_vec(); + let mut builder = NameBuilder::new_vec(); builder.append_slice(b"123").unwrap(); - let five_abs = builder.clone().into_dname().unwrap(); + let five_abs = builder.clone().into_name().unwrap(); assert_eq!(five_abs.len(), 5); builder.push(b'4').unwrap(); let five_rel = builder.clone().finish(); assert_eq!(five_rel.len(), 5); - let six_abs = builder.clone().into_dname().unwrap(); + let six_abs = builder.clone().into_name().unwrap(); assert_eq!(six_abs.len(), 6); builder.push(b'5').unwrap(); let six_rel = builder.finish(); @@ -469,11 +467,11 @@ mod test { .chain(five_rel) .is_err()); - let left = UncertainDname::from(left); + let left = UncertainName::from(left); assert_eq!(left.clone().chain(five_abs).unwrap().compose_len(), 255); assert!(left.clone().chain(six_abs.clone()).is_err()); - let left = UncertainDname::from(left.into_absolute().unwrap()); + let left = UncertainName::from(left.into_absolute().unwrap()); println!("{:?}", left); assert_eq!(left.chain(six_abs).unwrap().compose_len(), 251); } @@ -490,12 +488,12 @@ mod test { ); } - let w = RelativeDname::from_octets(b"\x03www".as_ref()).unwrap(); - let ec = RelativeDname::from_octets(b"\x07example\x03com".as_ref()) + let w = RelativeName::from_octets(b"\x03www".as_ref()).unwrap(); + let ec = RelativeName::from_octets(b"\x07example\x03com".as_ref()) .unwrap(); let ecr = - Dname::from_octets(b"\x07example\x03com\x00".as_ref()).unwrap(); - let fbr = Dname::from_octets(b"\x03foo\x03bar\x00".as_ref()).unwrap(); + Name::from_octets(b"\x07example\x03com\x00".as_ref()).unwrap(); + let fbr = Name::from_octets(b"\x03foo\x03bar\x00".as_ref()).unwrap(); check_impl( w.clone().chain(ec.clone()).unwrap(), @@ -509,25 +507,23 @@ mod test { w.clone() .chain(ec.clone()) .unwrap() - .chain(Dname::root_ref()) + .chain(Name::root_ref()) .unwrap(), &[b"www", b"example", b"com", b""], ); check_impl( - RelativeDname::empty_slice() - .chain(Dname::root_slice()) + RelativeName::empty_slice() + .chain(Name::root_slice()) .unwrap(), &[b""], ); check_impl( - UncertainDname::from(w.clone()).chain(ecr.clone()).unwrap(), + UncertainName::from(w.clone()).chain(ecr.clone()).unwrap(), &[b"www", b"example", b"com", b""], ); check_impl( - UncertainDname::from(ecr.clone()) - .chain(fbr.clone()) - .unwrap(), + UncertainName::from(ecr.clone()).chain(fbr.clone()).unwrap(), &[b"example", b"com", b""], ); } @@ -537,12 +533,12 @@ mod test { fn compose() { use std::vec::Vec; - let w = RelativeDname::from_octets(b"\x03www".as_ref()).unwrap(); - let ec = RelativeDname::from_octets(b"\x07example\x03com".as_ref()) + let w = RelativeName::from_octets(b"\x03www".as_ref()).unwrap(); + let ec = RelativeName::from_octets(b"\x07example\x03com".as_ref()) .unwrap(); let ecr = - Dname::from_octets(b"\x07example\x03com\x00".as_ref()).unwrap(); - let fbr = Dname::from_octets(b"\x03foo\x03bar\x00".as_ref()).unwrap(); + Name::from_octets(b"\x07example\x03com\x00".as_ref()).unwrap(); + let fbr = Name::from_octets(b"\x03foo\x03bar\x00".as_ref()).unwrap(); let mut buf = Vec::new(); infallible(w.clone().chain(ec.clone()).unwrap().compose(&mut buf)); @@ -557,7 +553,7 @@ mod test { w.clone() .chain(ec.clone()) .unwrap() - .chain(Dname::root_ref()) + .chain(Name::root_ref()) .unwrap() .compose(&mut buf), ); @@ -565,7 +561,7 @@ mod test { let mut buf = Vec::new(); infallible( - UncertainDname::from(w.clone()) + UncertainName::from(w.clone()) .chain(ecr.clone()) .unwrap() .compose(&mut buf), @@ -574,7 +570,7 @@ mod test { let mut buf = Vec::new(); infallible( - UncertainDname::from(ecr.clone()) + UncertainName::from(ecr.clone()) .chain(fbr.clone()) .unwrap() .compose(&mut buf), @@ -605,31 +601,31 @@ mod test { } // An empty relative name. - let empty = &RelativeDname::from_octets(b"".as_slice()).unwrap(); + let empty = &RelativeName::from_octets(b"".as_slice()).unwrap(); // An empty relative name wrapped in an uncertain name. - let uempty = &UncertainDname::from(empty.clone()); + let uempty = &UncertainName::from(empty.clone()); // A non-empty relative name. We are using two labels here just to // have that covered as well. let rel = - &RelativeDname::from_octets(b"\x03www\x07example".as_slice()) + &RelativeName::from_octets(b"\x03www\x07example".as_slice()) .unwrap(); // A non-empty relative name wrapped in an uncertain name. - let urel = &UncertainDname::from(rel.clone()); + let urel = &UncertainName::from(rel.clone()); // The root name which is an absolute name. - let root = &Dname::from_octets(b"\0".as_slice()).unwrap(); + let root = &Name::from_octets(b"\0".as_slice()).unwrap(); // The root name wrapped in an uncertain name. - let uroot = &UncertainDname::from(root.clone()); + let uroot = &UncertainName::from(root.clone()); // A “normal” absolute name. - let abs = &Dname::from_octets(b"\x03com\0".as_slice()).unwrap(); + let abs = &Name::from_octets(b"\x03com\0".as_slice()).unwrap(); // A “normal” absolute name wrapped in an uncertain name. - let uabs = &UncertainDname::from(abs.clone()); + let uabs = &UncertainName::from(abs.clone()); // Now we produce all possible cases and their expected result. First // result is for normal display, second is for fmt_with_dot. diff --git a/src/base/name/mod.rs b/src/base/name/mod.rs index f1aae3458..8b5f2b18e 100644 --- a/src/base/name/mod.rs +++ b/src/base/name/mod.rs @@ -2,13 +2,13 @@ //! //! This module provides various types for working with domain names. //! -//! Main types: [`Dname`], [`RelativeDname`], [`ParsedDname`], -//! [`UncertainDname`], [`DnameBuilder`].
-//! Main traits: [`ToDname`], [`ToRelativeDname`]. +//! Main types: [`Name`], [`RelativeName`], [`ParsedName`], [`UncertainName`], +//! [`NameBuilder`].
+//! Main traits: [`ToName`], [`ToRelativeName`]. //! -//! Domain names are a hierarchical description of the location of -//! records in a tree. They are formed from a sequence of *labels* that -//! describe the path through the tree upward from the leaf node to the root. +//! Domain names are a hierarchical description of the location of records in +//! a tree. They are formed from a sequence of *labels* that describe the path +//! through the tree upward from the leaf node to the root. //! //! ## Domain name representations //! @@ -30,8 +30,8 @@ //! name. In it, the octets of each label are interpreted as ASCII characters //! or, if there isn’t a printable one, as an escape sequence formed by a //! backslash followed by the three-digit decimal value of the octet. Labels -//! are separated by dots. If a dot (or a backslash) appears as an octet in -//! a label, they can be escaped by preceding them with a backslash. +//! are separated by dots. If a dot (or a backslash) appears as an octet in a +//! label, they can be escaped by preceding them with a backslash. //! //! This crate uses the presentation format when converting domain names from //! and to strings. @@ -48,18 +48,18 @@ //! //! In some cases, it is useful to have a domain name that doesn’t end with //! the root label. Such a name is called a *relative domain name* and, -//! conversely, a name that does end with the root label is called an -//! *abolute name*. Because these behave slightly differently, for instance, -//! you can’t include a relative name in a message, there are different -//! types for those two cases, [`Dname`] for absolute names and -//! [`RelativeDname`] for relative names. +//! conversely, a name that does end with the root label is called an *abolute +//! name*. Because these behave slightly differently, for instance, you can’t +//! include a relative name in a message, there are different types for those +//! two cases, [`Name`] for absolute names and [`RelativeName`] for relative +//! names. //! //! Sometimes, it isn’t quite clear if a domain name is absolute or relative. //! This happens in presentation format where the final dot at the end //! separating the empty and thus invisible root label is often omitted. For //! instance, instead of the strictly correct `www.example.com.` the slightly //! shorter `www.example.com` is accepted as an absolute name if it is clear -//! from context that the name is absolute. The [`UncertainDname`] type +//! from context that the name is absolute. The [`UncertainName`] type //! provides a means to keep such a name that may be absolute or relative. //! //! ## Name compression and parsed names. @@ -75,39 +75,39 @@ //! //! When making a relative name absolute to be included in a message, you //! often append a suffix to it. In order to avoid having to copy octets -//! around and make this cheap, the [`Chain`] type allows combining two -//! other name values. To make this work, the two traits [`ToDname`] -//! and [`ToRelativeDname`] allow writing code that is generic over any kind -//! of either absolute or relative domain name. +//! around and make this cheap, the [`Chain`] type allows combining two other +//! name values. To make this work, the two traits [`ToName`] and +//! [`ToRelativeName`] allow writing code that is generic over any kind of +//! either absolute or relative domain name. //! //! //! ## Building domain names //! -//! You can create a domain name value from its presentation format using -//! the `FromStr` trait. Alternatively, the [`DnameBuilder`] type allows you -//! to construct a name from scratch by appending octets, slices, or complete +//! You can create a domain name value from its presentation format using the +//! `FromStr` trait. Alternatively, the [`NameBuilder`] type allows you to +//! construct a name from scratch by appending octets, slices, or complete //! labels. +pub use self::absolute::{Name, NameError}; pub use self::builder::{ - DnameBuilder, FromStrError, PushError, PushNameError, + FromStrError, NameBuilder, PushError, PushNameError, }; pub use self::chain::{Chain, ChainIter, LongChainError, UncertainChainIter}; -pub use self::dname::{Dname, DnameError}; pub use self::label::{ Label, LabelTypeError, LongLabelError, OwnedLabel, SliceLabelsIter, SplitLabelError, }; -pub use self::parsed::{ParsedDname, ParsedDnameIter, ParsedSuffixIter}; +pub use self::parsed::{ParsedName, ParsedNameIter, ParsedSuffixIter}; pub use self::relative::{ - DnameIter, RelativeDname, RelativeDnameError, RelativeFromStrError, + NameIter, RelativeFromStrError, RelativeName, RelativeNameError, StripSuffixError, }; -pub use self::traits::{FlattenInto, ToDname, ToLabelIter, ToRelativeDname}; -pub use self::uncertain::UncertainDname; +pub use self::traits::{FlattenInto, ToLabelIter, ToName, ToRelativeName}; +pub use self::uncertain::UncertainName; +mod absolute; mod builder; mod chain; -mod dname; mod label; mod parsed; mod relative; diff --git a/src/base/name/parsed.rs b/src/base/name/parsed.rs index 1ce6b36a2..aba0b0659 100644 --- a/src/base/name/parsed.rs +++ b/src/base/name/parsed.rs @@ -5,10 +5,10 @@ use super::super::cmp::CanonicalOrd; use super::super::wire::{FormError, ParseError}; -use super::dname::Dname; +use super::absolute::Name; use super::label::{Label, LabelTypeError}; -use super::relative::RelativeDname; -use super::traits::{FlattenInto, ToDname, ToLabelIter}; +use super::relative::RelativeName; +use super::traits::{FlattenInto, ToLabelIter, ToName}; use core::{cmp, fmt, hash}; use octseq::builder::{ BuilderAppendError, EmptyBuilder, FreezeBuilder, FromBuilder, @@ -17,7 +17,7 @@ use octseq::builder::{ use octseq::octets::Octets; use octseq::parse::Parser; -//------------ ParsedDname --------------------------------------------------- +//------------ ParsedName ---------------------------------------------------- /// A domain name parsed from a DNS message. /// @@ -34,25 +34,25 @@ use octseq::parse::Parser; /// need the complete name. Many operations can be performed by just /// iterating over the labels which we can do in place. /// -/// `ParsedDname` deals with such names. It takes a copy of a [`Parser`] +/// `ParsedName` deals with such names. It takes a copy of a [`Parser`] /// representing a reference to the underlying DNS message and, if nedded, /// traverses over the name starting at the current position of the parser. /// When being created, the type quickly walks over the name to check that it /// is, indeed, a valid name. While this does take a bit of time, it spares /// you having to deal with possible parse errors later on. /// -/// `ParsedDname` implementes the [`ToDname`] trait, so you can use it +/// `ParsedName` implementes the [`ToName`] trait, so you can use it /// everywhere where a generic absolute domain name is accepted. In /// particular, you can compare it to other names or chain it to the end of a -/// relative name. If necessary, [`ToDname::to_name`] can be used to produce -/// a flat, self-contained [`Dname`]. +/// relative name. If necessary, [`ToName::to_name`] can be used to produce +/// a flat, self-contained [`Name`]. /// -/// [`Dname`]: struct.Dname.html +/// [`Name`]: struct.Name.html /// [`Parser`]: ../parse/struct.Parser.html -/// [`ToDname`]: trait.ToDname.html -/// [`ToDname::to_name`]: trait.ToDname.html#method.to_name +/// [`ToName`]: trait.ToName.html +/// [`ToName::to_name`]: trait.ToName.html#method.to_name #[derive(Clone, Copy)] -pub struct ParsedDname { +pub struct ParsedName { /// The octets the name is embedded in. /// /// This needs to be the full message as compression pointers in the name @@ -73,7 +73,7 @@ pub struct ParsedDname { compressed: bool, } -impl ParsedDname { +impl ParsedName { /// Returns whether the name is compressed. pub fn is_compressed(&self) -> bool { self.compressed @@ -90,13 +90,13 @@ impl ParsedDname { Octs: AsRef<[u8]>, { let mut res = Parser::from_ref(&self.octets); - res.advance(self.pos).expect("illegal pos in ParsedDname"); + res.advance(self.pos).expect("illegal pos in ParsedName"); res } /// Returns an equivalent name for a reference to the contained octets. - pub fn ref_octets(&self) -> ParsedDname<&Octs> { - ParsedDname { + pub fn ref_octets(&self) -> ParsedName<&Octs> { + ParsedName { octets: &self.octets, pos: self.pos, name_len: self.name_len, @@ -105,10 +105,10 @@ impl ParsedDname { } } -impl<'a, Octs: Octets + ?Sized> ParsedDname<&'a Octs> { +impl<'a, Octs: Octets + ?Sized> ParsedName<&'a Octs> { #[must_use] - pub fn deref_octets(&self) -> ParsedDname> { - ParsedDname { + pub fn deref_octets(&self) -> ParsedName> { + ParsedName { octets: self.octets.range(..), pos: self.pos, name_len: self.name_len, @@ -119,10 +119,10 @@ impl<'a, Octs: Octets + ?Sized> ParsedDname<&'a Octs> { /// # Working with Labels /// -impl> ParsedDname { +impl> ParsedName { /// Returns an iterator over the labels of the name. - pub fn iter(&self) -> ParsedDnameIter { - ParsedDnameIter::new(self.octets.as_ref(), self.pos, self.name_len) + pub fn iter(&self) -> ParsedNameIter { + ParsedNameIter::new(self.octets.as_ref(), self.pos, self.name_len) } /// Returns an iterator over the suffixes of the name. @@ -168,7 +168,7 @@ impl> ParsedDname { /// If this name is longer than just the root label, returns the first /// label as a relative name and removes it from the name itself. If the /// name is only the root label, returns `None` and does nothing. - pub fn split_first(&mut self) -> Option>> + pub fn split_first(&mut self) -> Option>> where Octs: Octets, { @@ -195,7 +195,7 @@ impl> ParsedDname { self.pos = range.end; self.name_len = name_len; Some(unsafe { - RelativeDname::from_octets_unchecked(self.octets.range(range)) + RelativeName::from_octets_unchecked(self.octets.range(range)) }) } @@ -228,15 +228,15 @@ impl> ParsedDname { } } -impl ParsedDname { +impl ParsedName { pub fn parse<'a, Src: Octets = Octs> + ?Sized>( parser: &mut Parser<'a, Src>, ) -> Result { - ParsedDname::parse_ref(parser).map(|res| res.deref_octets()) + ParsedName::parse_ref(parser).map(|res| res.deref_octets()) } } -impl<'a, Octs: AsRef<[u8]> + ?Sized> ParsedDname<&'a Octs> { +impl<'a, Octs: AsRef<[u8]> + ?Sized> ParsedName<&'a Octs> { pub fn parse_ref( parser: &mut Parser<'a, Octs>, ) -> Result { @@ -252,7 +252,7 @@ impl<'a, Octs: AsRef<[u8]> + ?Sized> ParsedDname<&'a Octs> { LabelType::Normal(0) => { // Root label. name_len += 1; - return Ok(ParsedDname { + return Ok(ParsedName { octets: parser.octets_ref(), pos, name_len, @@ -307,7 +307,7 @@ impl<'a, Octs: AsRef<[u8]> + ?Sized> ParsedDname<&'a Octs> { LabelType::Normal(0) => { // Root label. name_len += 1; - return Ok(ParsedDname { + return Ok(ParsedName { octets: parser.octets_ref(), pos, name_len, @@ -332,7 +332,7 @@ impl<'a, Octs: AsRef<[u8]> + ?Sized> ParsedDname<&'a Octs> { } } -impl ParsedDname<()> { +impl ParsedName<()> { /// Skip over a domain name. /// /// This will only check the uncompressed part of the name. If the name @@ -370,10 +370,10 @@ impl ParsedDname<()> { //--- From -impl> From> for ParsedDname { - fn from(name: Dname) -> ParsedDname { +impl> From> for ParsedName { + fn from(name: Name) -> ParsedName { let name_len = name.compose_len(); - ParsedDname { + ParsedName { octets: name.into_octets(), pos: 0, name_len, @@ -384,7 +384,7 @@ impl> From> for ParsedDname { //--- FlattenInto -impl FlattenInto> for ParsedDname +impl FlattenInto> for ParsedName where Octs: Octets, Target: FromBuilder, @@ -392,7 +392,7 @@ where { type AppendError = BuilderAppendError; - fn try_flatten_into(self) -> Result, Self::AppendError> { + fn try_flatten_into(self) -> Result, Self::AppendError> { let mut builder = Target::Builder::with_capacity(self.compose_len().into()); if let Some(slice) = self.as_flat_slice() { @@ -401,46 +401,46 @@ where self.iter_labels() .try_for_each(|label| label.compose(&mut builder))?; } - Ok(unsafe { Dname::from_octets_unchecked(builder.freeze()) }) + Ok(unsafe { Name::from_octets_unchecked(builder.freeze()) }) } } //--- PartialEq and Eq -impl PartialEq for ParsedDname +impl PartialEq for ParsedName where Octs: AsRef<[u8]>, - N: ToDname + ?Sized, + N: ToName + ?Sized, { fn eq(&self, other: &N) -> bool { self.name_eq(other) } } -impl> Eq for ParsedDname {} +impl> Eq for ParsedName {} //--- PartialOrd, Ord, and CanonicalOrd -impl PartialOrd for ParsedDname +impl PartialOrd for ParsedName where Octs: AsRef<[u8]>, - N: ToDname + ?Sized, + N: ToName + ?Sized, { fn partial_cmp(&self, other: &N) -> Option { Some(self.name_cmp(other)) } } -impl> Ord for ParsedDname { +impl> Ord for ParsedName { fn cmp(&self, other: &Self) -> cmp::Ordering { self.name_cmp(other) } } -impl CanonicalOrd for ParsedDname +impl CanonicalOrd for ParsedName where Octs: AsRef<[u8]>, - N: ToDname + ?Sized, + N: ToName + ?Sized, { fn canonical_cmp(&self, other: &N) -> cmp::Ordering { self.name_cmp(other) @@ -449,7 +449,7 @@ where //--- Hash -impl> hash::Hash for ParsedDname { +impl> hash::Hash for ParsedName { fn hash(&self, state: &mut H) { for item in self.iter() { item.hash(state) @@ -457,10 +457,10 @@ impl> hash::Hash for ParsedDname { } } -//--- ToLabelIter and ToDname +//--- ToLabelIter and ToName -impl> ToLabelIter for ParsedDname { - type LabelIter<'s> = ParsedDnameIter<'s> where Octs: 's; +impl> ToLabelIter for ParsedName { + type LabelIter<'s> = ParsedNameIter<'s> where Octs: 's; fn iter_labels(&self) -> Self::LabelIter<'_> { self.iter() @@ -471,7 +471,7 @@ impl> ToLabelIter for ParsedDname { } } -impl> ToDname for ParsedDname { +impl> ToName for ParsedName { fn as_flat_slice(&self) -> Option<&[u8]> { if self.compressed { None @@ -486,12 +486,12 @@ impl> ToDname for ParsedDname { //--- IntoIterator -impl<'a, Octs> IntoIterator for &'a ParsedDname +impl<'a, Octs> IntoIterator for &'a ParsedName where Octs: AsRef<[u8]>, { type Item = &'a Label; - type IntoIter = ParsedDnameIter<'a>; + type IntoIter = ParsedNameIter<'a>; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -500,7 +500,7 @@ where //--- Display and Debug -impl> fmt::Display for ParsedDname { +impl> fmt::Display for ParsedName { /// Formats the domain name. /// /// This will produce the domain name in common display format without @@ -517,28 +517,28 @@ impl> fmt::Display for ParsedDname { } } -impl> fmt::Debug for ParsedDname { +impl> fmt::Debug for ParsedName { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "ParsedDname({}.)", self) + write!(f, "ParsedName({}.)", self) } } -//------------ ParsedDnameIter ----------------------------------------------- +//------------ ParsedNameIter ----------------------------------------------- /// An iterator over the labels in a parsed domain name. #[derive(Clone)] -pub struct ParsedDnameIter<'a> { +pub struct ParsedNameIter<'a> { slice: &'a [u8], pos: usize, len: u16, } -impl<'a> ParsedDnameIter<'a> { +impl<'a> ParsedNameIter<'a> { /// Creates a new iterator from the parser and the name length. /// /// The parser must be positioned at the beginning of the name. pub(crate) fn new(slice: &'a [u8], pos: usize, len: u16) -> Self { - ParsedDnameIter { slice, pos, len } + ParsedNameIter { slice, pos, len } } /// Returns the next label. @@ -569,7 +569,7 @@ impl<'a> ParsedDnameIter<'a> { } } -impl<'a> Iterator for ParsedDnameIter<'a> { +impl<'a> Iterator for ParsedNameIter<'a> { type Item = &'a Label; fn next(&mut self) -> Option<&'a Label> { @@ -580,7 +580,7 @@ impl<'a> Iterator for ParsedDnameIter<'a> { } } -impl<'a> DoubleEndedIterator for ParsedDnameIter<'a> { +impl<'a> DoubleEndedIterator for ParsedNameIter<'a> { fn next_back(&mut self) -> Option<&'a Label> { if self.len == 0 { return None; @@ -602,12 +602,12 @@ impl<'a> DoubleEndedIterator for ParsedDnameIter<'a> { /// An iterator over ever shorter suffixes of a parsed domain name. #[derive(Clone)] pub struct ParsedSuffixIter<'a, Octs: ?Sized> { - name: Option>, + name: Option>, } impl<'a, Octs> ParsedSuffixIter<'a, Octs> { /// Creates a new iterator cloning `name`. - fn new(name: &'a ParsedDname) -> Self { + fn new(name: &'a ParsedName) -> Self { ParsedSuffixIter { name: Some(name.ref_octets()), } @@ -615,7 +615,7 @@ impl<'a, Octs> ParsedSuffixIter<'a, Octs> { } impl<'a, Octs: Octets + ?Sized> Iterator for ParsedSuffixIter<'a, Octs> { - type Item = ParsedDname>; + type Item = ParsedName>; fn next(&mut self) -> Option { let name = match self.name { @@ -755,7 +755,7 @@ mod test { ($bytes:expr, $start:expr, $len:expr, $compressed:expr) => {{ let mut parser = Parser::from_ref($bytes.as_ref()); parser.advance($start).unwrap(); - ParsedDname { + ParsedName { octets: $bytes.as_ref(), pos: $start, name_len: $len, @@ -792,7 +792,7 @@ mod test { #[test] fn iter() { - use crate::base::name::dname::test::cmp_iter; + use crate::base::name::absolute::test::cmp_iter; let labels: &[&[u8]] = &[b"www", b"example", b"com", b""]; cmp_iter(name!(root).iter(), &[b""]); @@ -803,7 +803,7 @@ mod test { #[test] fn iter_back() { - use crate::base::name::dname::test::cmp_iter_back; + use crate::base::name::absolute::test::cmp_iter_back; let labels: &[&[u8]] = &[b"", b"com", b"example", b"www"]; cmp_iter_back(name!(root).iter(), &[b""]); @@ -814,11 +814,11 @@ mod test { fn cmp_iter_suffixes<'a, I>(iter: I, labels: &[&[u8]]) where - I: Iterator>, + I: Iterator>, { for (name, labels) in iter.zip(labels) { let mut iter = name.iter(); - let labels = Dname::from_slice(labels).unwrap(); + let labels = Name::from_slice(labels).unwrap(); let mut labels_iter = labels.iter(); loop { match (iter.next(), labels_iter.next()) { @@ -868,53 +868,53 @@ mod test { let once_wec = name!(once); let twice_wec = name!(twice); - let test = Dname::root_ref(); + let test = Name::root_ref(); assert!(root.starts_with(&test)); assert!(!flat_wec.starts_with(&test)); assert!(!once_wec.starts_with(&test)); assert!(!twice_wec.starts_with(&test)); - let test = RelativeDname::empty_ref(); + let test = RelativeName::empty_ref(); assert!(root.starts_with(&test)); assert!(flat_wec.starts_with(&test)); assert!(once_wec.starts_with(&test)); assert!(twice_wec.starts_with(&test)); - let test = RelativeDname::from_slice(b"\x03www").unwrap(); + let test = RelativeName::from_slice(b"\x03www").unwrap(); assert!(!root.starts_with(&test)); assert!(flat_wec.starts_with(&test)); assert!(once_wec.starts_with(&test)); assert!(twice_wec.starts_with(&test)); - let test = RelativeDname::from_slice(b"\x03www\x07example").unwrap(); + let test = RelativeName::from_slice(b"\x03www\x07example").unwrap(); assert!(!root.starts_with(&test)); assert!(flat_wec.starts_with(&test)); assert!(once_wec.starts_with(&test)); assert!(twice_wec.starts_with(&test)); let test = - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(); + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(); assert!(!root.starts_with(&test)); assert!(flat_wec.starts_with(&test)); assert!(once_wec.starts_with(&test)); assert!(twice_wec.starts_with(&test)); - let test = Dname::from_slice(b"\x03www\x07example\x03com\0").unwrap(); + let test = Name::from_slice(b"\x03www\x07example\x03com\0").unwrap(); assert!(!root.starts_with(&test)); assert!(flat_wec.starts_with(&test)); assert!(once_wec.starts_with(&test)); assert!(twice_wec.starts_with(&test)); - let test = RelativeDname::from_slice(b"\x07example\x03com").unwrap(); + let test = RelativeName::from_slice(b"\x07example\x03com").unwrap(); assert!(!root.starts_with(&test)); assert!(!flat_wec.starts_with(&test)); assert!(!once_wec.starts_with(&test)); assert!(!twice_wec.starts_with(&test)); - let test = RelativeDname::from_octets(b"\x03www".as_ref()) + let test = RelativeName::from_octets(b"\x03www".as_ref()) .unwrap() .chain( - RelativeDname::from_octets(b"\x07example".as_ref()).unwrap(), + RelativeName::from_octets(b"\x07example".as_ref()).unwrap(), ) .unwrap(); assert!(!root.starts_with(&test)); @@ -923,7 +923,7 @@ mod test { assert!(twice_wec.starts_with(&test)); let test = test - .chain(RelativeDname::from_octets(b"\x03com".as_ref()).unwrap()) + .chain(RelativeName::from_octets(b"\x03com".as_ref()).unwrap()) .unwrap(); assert!(!root.starts_with(&test)); assert!(flat_wec.starts_with(&test)); @@ -937,9 +937,8 @@ mod test { let flat_wec = name!(flat); let once_wec = name!(once); let twice_wec = name!(twice); - let wecr = - Dname::from_octets(b"\x03www\x07example\x03com\0".as_ref()) - .unwrap(); + let wecr = Name::from_octets(b"\x03www\x07example\x03com\0".as_ref()) + .unwrap(); for name in wecr.iter_suffixes() { if name.is_root() { @@ -956,7 +955,7 @@ mod test { #[test] #[cfg(feature = "std")] fn split_first() { - fn split_first_wec(mut name: ParsedDname<&[u8]>) { + fn split_first_wec(mut name: ParsedName<&[u8]>) { assert_eq!( name.to_vec().as_slice(), b"\x03www\x07example\x03com\0" @@ -988,7 +987,7 @@ mod test { #[test] #[cfg(feature = "std")] fn parent() { - fn parent_wec(mut name: ParsedDname<&[u8]>) { + fn parent_wec(mut name: ParsedName<&[u8]>) { assert_eq!( name.to_vec().as_slice(), b"\x03www\x07example\x03com\0" @@ -1013,7 +1012,7 @@ mod test { fn parse_and_skip() { use std::vec::Vec; - fn name_eq(parsed: ParsedDname<&[u8]>, name: ParsedDname<&[u8]>) { + fn name_eq(parsed: ParsedName<&[u8]>, name: ParsedName<&[u8]>) { assert_eq!(parsed.octets, name.octets); assert_eq!(parsed.pos, name.pos); assert_eq!(parsed.name_len, name.name_len); @@ -1022,18 +1021,18 @@ mod test { fn parse( mut parser: Parser<&[u8]>, - equals: ParsedDname<&[u8]>, + equals: ParsedName<&[u8]>, compose_len: usize, ) { let end = parser.pos() + compose_len; - name_eq(ParsedDname::parse(&mut parser).unwrap(), equals); + name_eq(ParsedName::parse(&mut parser).unwrap(), equals); assert_eq!(parser.pos(), end); } - fn skip(name: ParsedDname<&[u8]>, len: usize) { + fn skip(name: ParsedName<&[u8]>, len: usize) { let mut parser = name.parser(); let pos = parser.pos(); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.pos(), pos + len); } @@ -1058,41 +1057,41 @@ mod test { // Short buffer in the middle of a label. let mut parser = p(b"\x03www\x07exam", 0); assert_eq!( - ParsedDname::parse(&mut parser.clone()), + ParsedName::parse(&mut parser.clone()), Err(ParseError::ShortInput) ); assert_eq!( - ParsedDname::skip(&mut parser), + ParsedName::skip(&mut parser), Err(ParseError::ShortInput) ); // Short buffer at end of label. let mut parser = p(b"\x03www\x07example", 0); assert_eq!( - ParsedDname::parse(&mut parser.clone()), + ParsedName::parse(&mut parser.clone()), Err(ParseError::ShortInput) ); assert_eq!( - ParsedDname::skip(&mut parser), + ParsedName::skip(&mut parser), Err(ParseError::ShortInput) ); // Compression pointer beyond the end of buffer. let mut parser = p(b"\x03www\xc0\xee12", 0); - assert!(ParsedDname::parse(&mut parser.clone()).is_err()); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert!(ParsedName::parse(&mut parser.clone()).is_err()); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.remaining(), 2); // Compression pointer to itself - assert!(ParsedDname::parse(&mut p(b"\x03www\xc0\x0412", 4)).is_err()); + assert!(ParsedName::parse(&mut p(b"\x03www\xc0\x0412", 4)).is_err()); // Compression pointer forward - assert!(ParsedDname::parse(&mut p(b"\x03www\xc0\x0612", 4)).is_err()); + assert!(ParsedName::parse(&mut p(b"\x03www\xc0\x0612", 4)).is_err()); // Bad label header. let mut parser = p(b"\x03www\x07example\xbffoo", 0); - assert!(ParsedDname::parse(&mut parser.clone()).is_err()); - assert!(ParsedDname::skip(&mut parser).is_err()); + assert!(ParsedName::parse(&mut parser.clone()).is_err()); + assert!(ParsedName::skip(&mut parser).is_err()); // Long name: 255 bytes is fine. let mut buf = Vec::from(&b"\x03123\0"[..]); @@ -1102,9 +1101,9 @@ mod test { buf.extend_from_slice(b"\xc0\x0012"); let mut parser = Parser::from_ref(buf.as_slice()); parser.advance(5).unwrap(); - let name = ParsedDname::parse(&mut parser.clone()).unwrap(); + let name = ParsedName::parse(&mut parser.clone()).unwrap(); assert_eq!(name.compose_len(), 255); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.remaining(), 2); // Long name: 256 bytes are bad. @@ -1115,26 +1114,26 @@ mod test { buf.extend_from_slice(b"\xc0\x0012"); let mut parser = Parser::from_ref(buf.as_slice()); parser.advance(6).unwrap(); - assert!(ParsedDname::parse(&mut parser.clone()).is_err()); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert!(ParsedName::parse(&mut parser.clone()).is_err()); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.remaining(), 2); // Long name through recursion let mut parser = p(b"\x03www\xc0\x0012", 0); - assert!(ParsedDname::parse(&mut parser.clone()).is_err()); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert!(ParsedName::parse(&mut parser.clone()).is_err()); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.remaining(), 2); // Single-step infinite recursion let mut parser = p(b"\xc0\x0012", 0); - assert!(ParsedDname::parse(&mut parser.clone()).is_err()); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert!(ParsedName::parse(&mut parser.clone()).is_err()); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.remaining(), 2); // Two-step infinite recursion let mut parser = p(b"\xc0\x02\xc0\x0012", 2); - assert!(ParsedDname::parse(&mut parser.clone()).is_err()); - assert_eq!(ParsedDname::skip(&mut parser), Ok(())); + assert!(ParsedName::parse(&mut parser.clone()).is_err()); + assert_eq!(ParsedName::skip(&mut parser), Ok(())); assert_eq!(parser.remaining(), 2); } @@ -1144,7 +1143,7 @@ mod test { use octseq::builder::infallible; use std::vec::Vec; - fn step(name: ParsedDname<&[u8]>, result: &[u8]) { + fn step(name: ParsedName<&[u8]>, result: &[u8]) { let mut buf = Vec::new(); infallible(name.compose(&mut buf)); assert_eq!(buf.as_slice(), result); @@ -1168,13 +1167,13 @@ mod test { #[test] fn eq() { - fn step(name: N) { + fn step(name: N) { assert_eq!(name!(flat), &name); assert_eq!(name!(once), &name); assert_eq!(name!(twice), &name); } - fn ne_step(name: N) { + fn ne_step(name: N) { assert_ne!(name!(flat), &name); assert_ne!(name!(once), &name); assert_ne!(name!(twice), &name); @@ -1184,21 +1183,21 @@ mod test { step(name!(once)); step(name!(twice)); - step(Dname::from_slice(b"\x03www\x07example\x03com\x00").unwrap()); - step(Dname::from_slice(b"\x03wWw\x07EXAMPLE\x03com\x00").unwrap()); + step(Name::from_slice(b"\x03www\x07example\x03com\x00").unwrap()); + step(Name::from_slice(b"\x03wWw\x07EXAMPLE\x03com\x00").unwrap()); step( - RelativeDname::from_octets(b"\x03www\x07example\x03com") + RelativeName::from_octets(b"\x03www\x07example\x03com") .unwrap() .chain_root(), ); step( - RelativeDname::from_octets(b"\x03www\x07example") + RelativeName::from_octets(b"\x03www\x07example") .unwrap() - .chain(Dname::from_octets(b"\x03com\x00").unwrap()) + .chain(Name::from_octets(b"\x03com\x00").unwrap()) .unwrap(), ); - ne_step(Dname::from_slice(b"\x03ww4\x07EXAMPLE\x03com\x00").unwrap()); + ne_step(Name::from_slice(b"\x03ww4\x07EXAMPLE\x03com\x00").unwrap()); } // XXX TODO Test for cmp and hash. diff --git a/src/base/name/relative.rs b/src/base/name/relative.rs index 6328388cf..1737dbc6b 100644 --- a/src/base/name/relative.rs +++ b/src/base/name/relative.rs @@ -3,11 +3,11 @@ //! This is a private module. Its public types are re-exported by the parent. use super::super::wire::ParseError; -use super::builder::{DnameBuilder, FromStrError, PushError}; +use super::absolute::Name; +use super::builder::{FromStrError, NameBuilder, PushError}; use super::chain::{Chain, LongChainError}; -use super::dname::Dname; use super::label::{Label, LabelTypeError, SplitLabelError}; -use super::traits::{ToLabelIter, ToRelativeDname}; +use super::traits::{ToLabelIter, ToRelativeName}; #[cfg(feature = "bytes")] use bytes::Bytes; use core::cmp::Ordering; @@ -23,29 +23,29 @@ use octseq::serde::{DeserializeOctets, SerializeOctets}; #[cfg(feature = "std")] use std::vec::Vec; -//------------ RelativeDname ------------------------------------------------- +//------------ RelativeName -------------------------------------------------- /// An uncompressed, relative domain name. /// /// A relative domain name is one that doesn’t end with the root label. As the /// name suggests, it is relative to some other domain name. This type wraps /// a octets sequence containing such a relative name similarly to the way -/// [`Dname`] wraps an absolute one. In fact, it behaves very similarly to -/// [`Dname`] taking into account differences when slicing and dicing names. +/// [`Name`] wraps an absolute one. In fact, it behaves very similarly to +/// [`Name`] taking into account differences when slicing and dicing names. /// -/// `RelativeDname` guarantees that the name is at most 254 bytes long. As the +/// `RelativeName` guarantees that the name is at most 254 bytes long. As the /// length limit for a domain name is actually 255 bytes, this means that you -/// can always safely turn a `RelativeDname` into a `Dname` by adding the root +/// can always safely turn a `RelativeName` into a `Name` by adding the root /// label (which is exactly one byte long). /// /// [`Bytes`]: ../../../bytes/struct.Bytes.html -/// [`Dname`]: struct.Dname.html +/// [`Name`]: struct.Name.html #[derive(Clone)] -pub struct RelativeDname(Octs); +pub struct RelativeName(Octs); /// # Creating Values /// -impl RelativeDname { +impl RelativeName { /// Creates a relative domain name from octets without checking. /// /// Since the content of the octets sequence can be anything, really, @@ -57,19 +57,19 @@ impl RelativeDname { /// encoded relative domain name. It must be at most 254 octets long. /// There must be no root labels anywhere in the name. pub const unsafe fn from_octets_unchecked(octets: Octs) -> Self { - RelativeDname(octets) + RelativeName(octets) } /// Creates a relative domain name from an octets sequence. /// /// This checks that `octets` contains a properly encoded relative domain /// name and fails if it doesn’t. - pub fn from_octets(octets: Octs) -> Result + pub fn from_octets(octets: Octs) -> Result where Octs: AsRef<[u8]>, { - RelativeDname::check_slice(octets.as_ref())?; - Ok(unsafe { RelativeDname::from_octets_unchecked(octets) }) + RelativeName::check_slice(octets.as_ref())?; + Ok(unsafe { RelativeName::from_octets_unchecked(octets) }) } /// Creates an empty relative domain name. @@ -78,7 +78,7 @@ impl RelativeDname { where Octs: From<&'static [u8]>, { - unsafe { RelativeDname::from_octets_unchecked(b"".as_ref().into()) } + unsafe { RelativeName::from_octets_unchecked(b"".as_ref().into()) } } /// Creates a relative domain name representing the wildcard label. @@ -93,7 +93,7 @@ impl RelativeDname { Octs: From<&'static [u8]>, { unsafe { - RelativeDname::from_octets_unchecked(b"\x01*".as_ref().into()) + RelativeName::from_octets_unchecked(b"\x01*".as_ref().into()) } } @@ -117,7 +117,7 @@ impl RelativeDname { + AsMut<[u8]>, C: IntoIterator, { - let mut builder = DnameBuilder::::new(); + let mut builder = NameBuilder::::new(); builder.append_chars(chars)?; if builder.in_label() || builder.is_empty() { Ok(builder.finish()) @@ -127,14 +127,14 @@ impl RelativeDname { } } -impl RelativeDname<[u8]> { +impl RelativeName<[u8]> { /// Creates a domain name from an octet slice without checking. /// /// # Safety /// /// The same rules as for `from_octets_unchecked` apply. pub(super) unsafe fn from_slice_unchecked(slice: &[u8]) -> &Self { - &*(slice as *const [u8] as *const RelativeDname<[u8]>) + &*(slice as *const [u8] as *const RelativeName<[u8]>) } /// Creates a relative domain name from an octet slice. @@ -144,10 +144,10 @@ impl RelativeDname<[u8]> { /// # Example /// /// ``` - /// use domain::base::name::RelativeDname; - /// RelativeDname::from_slice(b"\x0c_submissions\x04_tcp"); + /// use domain::base::name::RelativeName; + /// RelativeName::from_slice(b"\x0c_submissions\x04_tcp"); /// ``` - pub fn from_slice(slice: &[u8]) -> Result<&Self, RelativeDnameError> { + pub fn from_slice(slice: &[u8]) -> Result<&Self, RelativeNameError> { Self::check_slice(slice)?; Ok(unsafe { Self::from_slice_unchecked(slice) }) } @@ -166,14 +166,14 @@ impl RelativeDname<[u8]> { /// Checks whether an octet slice contains a correctly encoded name. pub(super) fn check_slice( mut slice: &[u8], - ) -> Result<(), RelativeDnameError> { + ) -> Result<(), RelativeNameError> { if slice.len() > 254 { - return Err(RelativeDnameErrorEnum::LongName.into()); + return Err(RelativeNameErrorEnum::LongName.into()); } while !slice.is_empty() { let (label, tail) = Label::split_from(slice)?; if label.is_root() { - return Err(RelativeDnameErrorEnum::AbsoluteName.into()); + return Err(RelativeNameErrorEnum::AbsoluteName.into()); } slice = tail; } @@ -181,7 +181,7 @@ impl RelativeDname<[u8]> { } } -impl RelativeDname<&'static [u8]> { +impl RelativeName<&'static [u8]> { /// Creates an empty relative name atop a slice reference. #[must_use] pub fn empty_ref() -> Self { @@ -196,7 +196,7 @@ impl RelativeDname<&'static [u8]> { } #[cfg(feature = "std")] -impl RelativeDname> { +impl RelativeName> { /// Creates an empty relative name atop a `Vec`. #[must_use] pub fn empty_vec() -> Self { @@ -216,7 +216,7 @@ impl RelativeDname> { } #[cfg(feature = "bytes")] -impl RelativeDname { +impl RelativeName { /// Creates an empty relative name atop a bytes value. pub fn empty_bytes() -> Self { Self::empty() @@ -235,7 +235,7 @@ impl RelativeDname { /// # Conversions /// -impl RelativeDname { +impl RelativeName { /// Returns a reference to the underlying octets. pub fn as_octets(&self) -> &Octs { &self.0 @@ -250,8 +250,8 @@ impl RelativeDname { } /// Returns a domain name using a reference to the octets. - pub fn for_ref(&self) -> RelativeDname<&Octs> { - unsafe { RelativeDname::from_octets_unchecked(&self.0) } + pub fn for_ref(&self) -> RelativeName<&Octs> { + unsafe { RelativeName::from_octets_unchecked(&self.0) } } /// Returns a reference to an octets slice with the content of the name. @@ -263,11 +263,11 @@ impl RelativeDname { } /// Returns a domain name for the octets slice of the content. - pub fn for_slice(&self) -> &RelativeDname<[u8]> + pub fn for_slice(&self) -> &RelativeName<[u8]> where Octs: AsRef<[u8]>, { - unsafe { RelativeDname::from_slice_unchecked(self.0.as_ref()) } + unsafe { RelativeName::from_slice_unchecked(self.0.as_ref()) } } /// Converts the name into its canonical form. @@ -279,16 +279,16 @@ impl RelativeDname { } } -impl RelativeDname { +impl RelativeName { /// Converts the name into a domain name builder for appending data. /// /// This method is only available for octets sequences that have an /// associated octets builder such as `Vec` or `Bytes`. - pub fn into_builder(self) -> DnameBuilder<::Builder> + pub fn into_builder(self) -> NameBuilder<::Builder> where Octs: IntoBuilder, { - unsafe { DnameBuilder::from_builder_unchecked(self.0.into_builder()) } + unsafe { NameBuilder::from_builder_unchecked(self.0.into_builder()) } } /// Converts the name into an absolute name by appending the root label. @@ -298,13 +298,13 @@ impl RelativeDname { /// such as `Vec`. /// /// [`chain_root`]: #method.chain_root - pub fn into_absolute(self) -> Result, PushError> + pub fn into_absolute(self) -> Result, PushError> where Octs: IntoBuilder, ::Builder: FreezeBuilder + AsRef<[u8]> + AsMut<[u8]>, { - self.into_builder().into_dname() + self.into_builder().into_name() } /// Chains another name to the end of this name. @@ -327,17 +327,17 @@ impl RelativeDname { } /// Creates an absolute name by chaining the root label to it. - pub fn chain_root(self) -> Chain> + pub fn chain_root(self) -> Chain> where Octs: AsRef<[u8]>, { - self.chain(Dname::root()).unwrap() + self.chain(Name::root()).unwrap() } } /// # Properties /// -impl + ?Sized> RelativeDname { +impl + ?Sized> RelativeName { /// Returns the length of the name. pub fn len(&self) -> usize { self.0.as_ref().len() @@ -351,10 +351,10 @@ impl + ?Sized> RelativeDname { /// # Working with Labels /// -impl + ?Sized> RelativeDname { +impl + ?Sized> RelativeName { /// Returns an iterator over the labels of the domain name. - pub fn iter(&self) -> DnameIter { - DnameIter::new(self.0.as_ref()) + pub fn iter(&self) -> NameIter { + NameIter::new(self.0.as_ref()) } /// Returns the number of labels in the name. @@ -456,10 +456,10 @@ impl + ?Sized> RelativeDname { pub fn slice( &self, range: impl RangeBounds, - ) -> &RelativeDname<[u8]> { + ) -> &RelativeName<[u8]> { self.check_bounds(&range); unsafe { - RelativeDname::from_slice_unchecked(self.0.as_ref().range(range)) + RelativeName::from_slice_unchecked(self.0.as_ref().range(range)) } } @@ -476,16 +476,16 @@ impl + ?Sized> RelativeDname { pub fn range( &self, range: impl RangeBounds, - ) -> RelativeDname<::Range<'_>> + ) -> RelativeName<::Range<'_>> where Octs: Octets, { self.check_bounds(&range); - unsafe { RelativeDname::from_octets_unchecked(self.0.range(range)) } + unsafe { RelativeName::from_octets_unchecked(self.0.range(range)) } } } -impl + ?Sized> RelativeDname { +impl + ?Sized> RelativeName { /// Splits the name into two at the given position. /// /// Returns a pair of the left and right part of the split name. @@ -497,18 +497,15 @@ impl + ?Sized> RelativeDname { pub fn split( &self, mid: usize, - ) -> ( - RelativeDname>, - RelativeDname>, - ) + ) -> (RelativeName>, RelativeName>) where Octs: Octets, { self.check_index(mid); unsafe { ( - RelativeDname::from_octets_unchecked(self.0.range(..mid)), - RelativeDname::from_octets_unchecked(self.0.range(mid..)), + RelativeName::from_octets_unchecked(self.0.range(..mid)), + RelativeName::from_octets_unchecked(self.0.range(mid..)), ) } } @@ -535,7 +532,7 @@ impl + ?Sized> RelativeDname { /// is empty, returns `None`. pub fn split_first( &self, - ) -> Option<(&Label, RelativeDname>)> + ) -> Option<(&Label, RelativeName>)> where Octs: Octets, { @@ -549,7 +546,7 @@ impl + ?Sized> RelativeDname { /// Returns the parent name. /// /// Returns `None` if the name was empty. - pub fn parent(&self) -> Option>> + pub fn parent(&self) -> Option>> where Octs: Octets, { @@ -562,7 +559,7 @@ impl + ?Sized> RelativeDname { /// [`ends_with`] doesn’t return `true`. /// /// [`ends_with`]: #method.ends_with - pub fn strip_suffix( + pub fn strip_suffix( &mut self, base: &N, ) -> Result<(), StripSuffixError> @@ -581,13 +578,13 @@ impl + ?Sized> RelativeDname { //--- AsRef -impl AsRef for RelativeDname { +impl AsRef for RelativeName { fn as_ref(&self) -> &Octs { &self.0 } } -impl + ?Sized> AsRef<[u8]> for RelativeDname { +impl + ?Sized> AsRef<[u8]> for RelativeName { fn as_ref(&self) -> &[u8] { self.0.as_ref() } @@ -595,14 +592,14 @@ impl + ?Sized> AsRef<[u8]> for RelativeDname { //--- OctetsFrom -impl OctetsFrom> for RelativeDname +impl OctetsFrom> for RelativeName where Octs: OctetsFrom, { type Error = Octs::Error; fn try_octets_from( - source: RelativeDname, + source: RelativeName, ) -> Result { Octs::try_octets_from(source.0) .map(|octets| unsafe { Self::from_octets_unchecked(octets) }) @@ -611,7 +608,7 @@ where //--- FromStr -impl FromStr for RelativeDname +impl FromStr for RelativeName where Octs: FromBuilder, ::Builder: EmptyBuilder @@ -635,13 +632,13 @@ where } } -//--- ToLabelIter and ToRelativeDname +//--- ToLabelIter and ToRelativeName -impl ToLabelIter for RelativeDname +impl ToLabelIter for RelativeName where Octs: AsRef<[u8]> + ?Sized, { - type LabelIter<'a> = DnameIter<'a> where Octs: 'a; + type LabelIter<'a> = NameIter<'a> where Octs: 'a; fn iter_labels(&self) -> Self::LabelIter<'_> { self.iter() @@ -652,7 +649,7 @@ where } } -impl + ?Sized> ToRelativeDname for RelativeDname { +impl + ?Sized> ToRelativeName for RelativeName { fn as_flat_slice(&self) -> Option<&[u8]> { Some(self.0.as_ref()) } @@ -664,12 +661,12 @@ impl + ?Sized> ToRelativeDname for RelativeDname { //--- IntoIterator -impl<'a, Octs> IntoIterator for &'a RelativeDname +impl<'a, Octs> IntoIterator for &'a RelativeName where Octs: AsRef<[u8]> + ?Sized, { type Item = &'a Label; - type IntoIter = DnameIter<'a>; + type IntoIter = NameIter<'a>; fn into_iter(self) -> Self::IntoIter { self.iter() @@ -678,31 +675,31 @@ where //--- PartialEq and Eq -impl PartialEq for RelativeDname +impl PartialEq for RelativeName where Octs: AsRef<[u8]> + ?Sized, - N: ToRelativeDname + ?Sized, + N: ToRelativeName + ?Sized, { fn eq(&self, other: &N) -> bool { self.name_eq(other) } } -impl + ?Sized> Eq for RelativeDname {} +impl + ?Sized> Eq for RelativeName {} //--- PartialOrd and Ord -impl PartialOrd for RelativeDname +impl PartialOrd for RelativeName where Octs: AsRef<[u8]> + ?Sized, - N: ToRelativeDname + ?Sized, + N: ToRelativeName + ?Sized, { fn partial_cmp(&self, other: &N) -> Option { Some(self.name_cmp(other)) } } -impl + ?Sized> Ord for RelativeDname { +impl + ?Sized> Ord for RelativeName { fn cmp(&self, other: &Self) -> cmp::Ordering { self.name_cmp(other) } @@ -710,7 +707,7 @@ impl + ?Sized> Ord for RelativeDname { //--- Hash -impl + ?Sized> hash::Hash for RelativeDname { +impl + ?Sized> hash::Hash for RelativeName { fn hash(&self, state: &mut H) { for item in self.iter() { item.hash(state) @@ -720,7 +717,7 @@ impl + ?Sized> hash::Hash for RelativeDname { //--- Display and Debug -impl + ?Sized> fmt::Display for RelativeDname { +impl + ?Sized> fmt::Display for RelativeName { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let mut iter = self.iter(); match iter.next() { @@ -735,47 +732,47 @@ impl + ?Sized> fmt::Display for RelativeDname { } } -impl + ?Sized> fmt::Debug for RelativeDname { +impl + ?Sized> fmt::Debug for RelativeName { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "RelativeDname({})", self) + write!(f, "RelativeName({})", self) } } //--- AsRef and Borrow -impl AsRef> for RelativeDname +impl AsRef> for RelativeName where Octs: AsRef<[u8]> + ?Sized, { - fn as_ref(&self) -> &RelativeDname<[u8]> { + fn as_ref(&self) -> &RelativeName<[u8]> { self.for_slice() } } /// Borrow a relative domain name. /// -/// Containers holding an owned `RelativeDname<_>` may be queried with name +/// Containers holding an owned `RelativeName<_>` may be queried with name /// over a slice. This `Borrow<_>` impl supports user code querying containers /// with compatible-but-different types like the following example: /// /// ``` /// use std::collections::HashMap; /// -/// use domain::base::RelativeDname; +/// use domain::base::RelativeName; /// /// fn get_description( -/// hash: &HashMap>, String> +/// hash: &HashMap>, String> /// ) -> Option<&str> { -/// let lookup_name: &RelativeDname<[u8]> = -/// RelativeDname::from_slice(b"\x03ftp").unwrap(); +/// let lookup_name: &RelativeName<[u8]> = +/// RelativeName::from_slice(b"\x03ftp").unwrap(); /// hash.get(lookup_name).map(|x| x.as_ref()) /// } /// ``` -impl borrow::Borrow> for RelativeDname +impl borrow::Borrow> for RelativeName where Octs: AsRef<[u8]>, { - fn borrow(&self) -> &RelativeDname<[u8]> { + fn borrow(&self) -> &RelativeName<[u8]> { self.for_slice() } } @@ -783,7 +780,7 @@ where //--- Serialize and Deserialize #[cfg(feature = "serde")] -impl serde::Serialize for RelativeDname +impl serde::Serialize for RelativeName where Octs: AsRef<[u8]> + SerializeOctets + ?Sized, { @@ -793,12 +790,12 @@ where ) -> Result { if serializer.is_human_readable() { serializer.serialize_newtype_struct( - "RelativeDname", + "RelativeName", &format_args!("{}", self), ) } else { serializer.serialize_newtype_struct( - "RelativeDname", + "RelativeName", &self.0.as_serialized_octets(), ) } @@ -806,7 +803,7 @@ where } #[cfg(feature = "serde")] -impl<'de, Octs> serde::Deserialize<'de> for RelativeDname +impl<'de, Octs> serde::Deserialize<'de> for RelativeName where Octs: FromBuilder + DeserializeOctets<'de>, ::Builder: FreezeBuilder @@ -829,7 +826,7 @@ where + AsRef<[u8]> + AsMut<[u8]>, { - type Value = RelativeDname; + type Value = RelativeName; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("a relative domain name") @@ -839,7 +836,7 @@ where self, v: &str, ) -> Result { - let mut builder = DnameBuilder::::new(); + let mut builder = NameBuilder::::new(); builder.append_chars(v.chars()).map_err(E::custom)?; Ok(builder.finish()) } @@ -849,7 +846,7 @@ where value: &'de [u8], ) -> Result { self.0.visit_borrowed_bytes(value).and_then(|octets| { - RelativeDname::from_octets(octets).map_err(E::custom) + RelativeName::from_octets(octets).map_err(E::custom) }) } @@ -859,7 +856,7 @@ where value: std::vec::Vec, ) -> Result { self.0.visit_byte_buf(value).and_then(|octets| { - RelativeDname::from_octets(octets).map_err(E::custom) + RelativeName::from_octets(octets).map_err(E::custom) }) } } @@ -874,7 +871,7 @@ where + AsRef<[u8]> + AsMut<[u8]>, { - type Value = RelativeDname; + type Value = RelativeName; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("a relative domain name") @@ -897,27 +894,27 @@ where } deserializer.deserialize_newtype_struct( - "RelativeDname", + "RelativeName", NewtypeVisitor(PhantomData), ) } } -//------------ DnameIter ----------------------------------------------------- +//------------ NameIter ----------------------------------------------------- /// An iterator over the labels in an uncompressed name. #[derive(Clone, Debug)] -pub struct DnameIter<'a> { +pub struct NameIter<'a> { slice: &'a [u8], } -impl<'a> DnameIter<'a> { +impl<'a> NameIter<'a> { pub(super) fn new(slice: &'a [u8]) -> Self { - DnameIter { slice } + NameIter { slice } } } -impl<'a> Iterator for DnameIter<'a> { +impl<'a> Iterator for NameIter<'a> { type Item = &'a Label; fn next(&mut self) -> Option { @@ -930,7 +927,7 @@ impl<'a> Iterator for DnameIter<'a> { } } -impl<'a> DoubleEndedIterator for DnameIter<'a> { +impl<'a> DoubleEndedIterator for NameIter<'a> { fn next_back(&mut self) -> Option { if self.slice.is_empty() { return None; @@ -951,14 +948,14 @@ impl<'a> DoubleEndedIterator for DnameIter<'a> { //============ Error Types =================================================== -//------------ RelativeDnameError -------------------------------------------- +//------------ RelativeNameError -------------------------------------------- /// An error happened while creating a domain name from octets. #[derive(Clone, Copy, Debug, Eq, PartialEq)] -pub struct RelativeDnameError(RelativeDnameErrorEnum); +pub struct RelativeNameError(RelativeNameErrorEnum); #[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum RelativeDnameErrorEnum { +enum RelativeNameErrorEnum { /// A bad label was encountered. BadLabel(LabelTypeError), @@ -977,48 +974,46 @@ enum RelativeDnameErrorEnum { //--- From -impl From for RelativeDnameError { +impl From for RelativeNameError { fn from(err: LabelTypeError) -> Self { - Self(RelativeDnameErrorEnum::BadLabel(err)) + Self(RelativeNameErrorEnum::BadLabel(err)) } } -impl From for RelativeDnameError { +impl From for RelativeNameError { fn from(err: SplitLabelError) -> Self { Self(match err { SplitLabelError::Pointer(_) => { - RelativeDnameErrorEnum::CompressedName + RelativeNameErrorEnum::CompressedName } - SplitLabelError::BadType(t) => { - RelativeDnameErrorEnum::BadLabel(t) - } - SplitLabelError::ShortInput => RelativeDnameErrorEnum::ShortInput, + SplitLabelError::BadType(t) => RelativeNameErrorEnum::BadLabel(t), + SplitLabelError::ShortInput => RelativeNameErrorEnum::ShortInput, }) } } -impl From for RelativeDnameError { - fn from(err: RelativeDnameErrorEnum) -> Self { +impl From for RelativeNameError { + fn from(err: RelativeNameErrorEnum) -> Self { Self(err) } } //--- Display and Error -impl fmt::Display for RelativeDnameError { +impl fmt::Display for RelativeNameError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self.0 { - RelativeDnameErrorEnum::BadLabel(err) => err.fmt(f), - RelativeDnameErrorEnum::CompressedName => { + RelativeNameErrorEnum::BadLabel(err) => err.fmt(f), + RelativeNameErrorEnum::CompressedName => { f.write_str("compressed domain name") } - RelativeDnameErrorEnum::ShortInput => { + RelativeNameErrorEnum::ShortInput => { ParseError::ShortInput.fmt(f) } - RelativeDnameErrorEnum::LongName => { + RelativeNameErrorEnum::LongName => { f.write_str("long domain name") } - RelativeDnameErrorEnum::AbsoluteName => { + RelativeNameErrorEnum::AbsoluteName => { f.write_str("absolute domain name") } } @@ -1026,7 +1021,7 @@ impl fmt::Display for RelativeDnameError { } #[cfg(feature = "std")] -impl std::error::Error for RelativeDnameError {} +impl std::error::Error for RelativeNameError {} //------------ RelativeFromStrError ------------------------------------------ @@ -1098,19 +1093,19 @@ mod test { #[test] #[cfg(feature = "std")] fn impls() { - fn assert_to_relative_dname(_: &T) {} + fn assert_to_relative_name(_: &T) {} - assert_to_relative_dname( - RelativeDname::from_slice(b"\x03www".as_ref()).unwrap(), + assert_to_relative_name( + RelativeName::from_slice(b"\x03www".as_ref()).unwrap(), ); - assert_to_relative_dname( - &RelativeDname::from_octets(b"\x03www").unwrap(), + assert_to_relative_name( + &RelativeName::from_octets(b"\x03www").unwrap(), ); - assert_to_relative_dname( - &RelativeDname::from_octets(b"\x03www".as_ref()).unwrap(), + assert_to_relative_name( + &RelativeName::from_octets(b"\x03www".as_ref()).unwrap(), ); - assert_to_relative_dname( - &RelativeDname::from_octets(Vec::from(b"\x03www".as_ref())) + assert_to_relative_name( + &RelativeName::from_octets(Vec::from(b"\x03www".as_ref())) .unwrap(), ); } @@ -1118,54 +1113,54 @@ mod test { #[cfg(feature = "bytes")] #[test] fn impl_bytes() { - fn assert_to_relative_dname(_: &T) {} + fn assert_to_relative_name(_: &T) {} - assert_to_relative_dname( - &RelativeDname::from_octets(Bytes::from(b"\x03www".as_ref())) + assert_to_relative_name( + &RelativeName::from_octets(Bytes::from(b"\x03www".as_ref())) .unwrap(), ); } #[test] fn empty() { - assert_eq!(RelativeDname::empty_slice().as_slice(), b""); - assert_eq!(RelativeDname::empty_ref().as_slice(), b""); + assert_eq!(RelativeName::empty_slice().as_slice(), b""); + assert_eq!(RelativeName::empty_ref().as_slice(), b""); #[cfg(feature = "std")] { - assert_eq!(RelativeDname::empty_vec().as_slice(), b""); + assert_eq!(RelativeName::empty_vec().as_slice(), b""); } } #[test] fn wildcard() { - assert_eq!(RelativeDname::wildcard_slice().as_slice(), b"\x01*"); - assert_eq!(RelativeDname::wildcard_ref().as_slice(), b"\x01*"); + assert_eq!(RelativeName::wildcard_slice().as_slice(), b"\x01*"); + assert_eq!(RelativeName::wildcard_ref().as_slice(), b"\x01*"); #[cfg(feature = "std")] { - assert_eq!(RelativeDname::wildcard_vec().as_slice(), b"\x01*"); + assert_eq!(RelativeName::wildcard_vec().as_slice(), b"\x01*"); } } #[cfg(feature = "bytes")] #[test] fn literals_bytes() { - assert_eq!(RelativeDname::empty_bytes().as_slice(), b""); - assert_eq!(RelativeDname::wildcard_bytes().as_slice(), b"\x01*"); + assert_eq!(RelativeName::empty_bytes().as_slice(), b""); + assert_eq!(RelativeName::wildcard_bytes().as_slice(), b"\x01*"); } #[test] #[cfg(feature = "std")] fn from_slice() { // good names - assert_eq!(RelativeDname::from_slice(b"").unwrap().as_slice(), b""); + assert_eq!(RelativeName::from_slice(b"").unwrap().as_slice(), b""); assert_eq!( - RelativeDname::from_slice(b"\x03www").unwrap().as_slice(), + RelativeName::from_slice(b"\x03www").unwrap().as_slice(), b"\x03www" ); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example") + RelativeName::from_slice(b"\x03www\x07example") .unwrap() .as_slice(), b"\x03www\x07example" @@ -1173,27 +1168,27 @@ mod test { // absolute names assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example\x03com\0"), - Err(RelativeDnameError(RelativeDnameErrorEnum::AbsoluteName)) + RelativeName::from_slice(b"\x03www\x07example\x03com\0"), + Err(RelativeNameError(RelativeNameErrorEnum::AbsoluteName)) ); assert_eq!( - RelativeDname::from_slice(b"\0"), - Err(RelativeDnameError(RelativeDnameErrorEnum::AbsoluteName)) + RelativeName::from_slice(b"\0"), + Err(RelativeNameError(RelativeNameErrorEnum::AbsoluteName)) ); // bytes shorter than what label length says. assert_eq!( - RelativeDname::from_slice(b"\x03www\x07exa"), - Err(RelativeDnameError(RelativeDnameErrorEnum::ShortInput)) + RelativeName::from_slice(b"\x03www\x07exa"), + Err(RelativeNameError(RelativeNameErrorEnum::ShortInput)) ); // label 63 long ok, 64 bad. let mut slice = [0u8; 64]; slice[0] = 63; - assert!(RelativeDname::from_slice(&slice[..]).is_ok()); + assert!(RelativeName::from_slice(&slice[..]).is_ok()); let mut slice = [0u8; 65]; slice[0] = 64; - assert!(RelativeDname::from_slice(&slice[..]).is_err()); + assert!(RelativeName::from_slice(&slice[..]).is_err()); // name 254 long ok, 255 bad. let mut buf = Vec::new(); @@ -1203,22 +1198,22 @@ mod test { assert_eq!(buf.len(), 250); let mut tmp = buf.clone(); tmp.extend_from_slice(b"\x03123"); - assert_eq!(RelativeDname::from_slice(&tmp).map(|_| ()), Ok(())); + assert_eq!(RelativeName::from_slice(&tmp).map(|_| ()), Ok(())); buf.extend_from_slice(b"\x041234"); - assert!(RelativeDname::from_slice(&buf).is_err()); + assert!(RelativeName::from_slice(&buf).is_err()); // bad label heads: compressed, other types. assert_eq!( - RelativeDname::from_slice(b"\xa2asdasds"), + RelativeName::from_slice(b"\xa2asdasds"), Err(LabelTypeError::Undefined.into()) ); assert_eq!( - RelativeDname::from_slice(b"\x62asdasds"), + RelativeName::from_slice(b"\x62asdasds"), Err(LabelTypeError::Extended(0x62).into()) ); assert_eq!( - RelativeDname::from_slice(b"\xccasdasds"), - Err(RelativeDnameError(RelativeDnameErrorEnum::CompressedName)) + RelativeName::from_slice(b"\xccasdasds"), + Err(RelativeNameError(RelativeNameErrorEnum::CompressedName)) ); } @@ -1226,25 +1221,25 @@ mod test { #[cfg(feature = "std")] fn from_str() { // empty name - assert_eq!(RelativeDname::vec_from_str("").unwrap().as_slice(), b""); + assert_eq!(RelativeName::vec_from_str("").unwrap().as_slice(), b""); // relative name assert_eq!( - RelativeDname::vec_from_str("www.example") + RelativeName::vec_from_str("www.example") .unwrap() .as_slice(), b"\x03www\x07example" ); // absolute name - assert!(RelativeDname::vec_from_str("www.example.com.").is_err()); + assert!(RelativeName::vec_from_str("www.example.com.").is_err()); } #[test] #[cfg(feature = "std")] fn into_absolute() { assert_eq!( - RelativeDname::from_octets(Vec::from( + RelativeName::from_octets(Vec::from( b"\x03www\x07example\x03com".as_ref() )) .unwrap() @@ -1262,7 +1257,7 @@ mod test { assert_eq!(buf.len(), 250); let mut tmp = buf.clone(); tmp.extend_from_slice(b"\x03123"); - RelativeDname::from_octets(tmp) + RelativeName::from_octets(tmp) .unwrap() .into_absolute() .unwrap(); @@ -1271,11 +1266,11 @@ mod test { #[test] #[cfg(feature = "std")] fn make_canonical() { - let mut name = Dname::vec_from_str("wWw.exAmpLE.coM.").unwrap(); + let mut name = Name::vec_from_str("wWw.exAmpLE.coM.").unwrap(); name.make_canonical(); assert_eq!( name, - Dname::from_octets(b"\x03www\x07example\x03com\0").unwrap() + Name::from_octets(b"\x03www\x07example\x03com\0").unwrap() ); } @@ -1284,8 +1279,8 @@ mod test { #[test] fn chain_root() { assert_eq!( - Dname::from_octets(b"\x03www\x07example\x03com\0").unwrap(), - RelativeDname::from_octets(b"\x03www\x07example\x03com") + Name::from_octets(b"\x03www\x07example\x03com\0").unwrap(), + RelativeName::from_octets(b"\x03www\x07example\x03com") .unwrap() .chain_root() ); @@ -1293,12 +1288,12 @@ mod test { #[test] fn iter() { - use crate::base::name::dname::test::cmp_iter; + use crate::base::name::absolute::test::cmp_iter; - cmp_iter(RelativeDname::empty_ref().iter(), &[]); - cmp_iter(RelativeDname::wildcard_ref().iter(), &[b"*"]); + cmp_iter(RelativeName::empty_ref().iter(), &[]); + cmp_iter(RelativeName::wildcard_ref().iter(), &[b"*"]); cmp_iter( - RelativeDname::from_slice(b"\x03www\x07example\x03com") + RelativeName::from_slice(b"\x03www\x07example\x03com") .unwrap() .iter(), &[b"www", b"example", b"com"], @@ -1307,12 +1302,12 @@ mod test { #[test] fn iter_back() { - use crate::base::name::dname::test::cmp_iter_back; + use crate::base::name::absolute::test::cmp_iter_back; - cmp_iter_back(RelativeDname::empty_ref().iter(), &[]); - cmp_iter_back(RelativeDname::wildcard_ref().iter(), &[b"*"]); + cmp_iter_back(RelativeName::empty_ref().iter(), &[]); + cmp_iter_back(RelativeName::wildcard_ref().iter(), &[b"*"]); cmp_iter_back( - RelativeDname::from_slice(b"\x03www\x07example\x03com") + RelativeName::from_slice(b"\x03www\x07example\x03com") .unwrap() .iter(), &[b"com", b"example", b"www"], @@ -1321,10 +1316,10 @@ mod test { #[test] fn label_count() { - assert_eq!(RelativeDname::empty_ref().label_count(), 0); - assert_eq!(RelativeDname::wildcard_slice().label_count(), 1); + assert_eq!(RelativeName::empty_ref().label_count(), 0); + assert_eq!(RelativeName::wildcard_slice().label_count(), 1); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example\x03com") + RelativeName::from_slice(b"\x03www\x07example\x03com") .unwrap() .label_count(), 3 @@ -1333,9 +1328,9 @@ mod test { #[test] fn first() { - assert_eq!(RelativeDname::empty_slice().first(), None); + assert_eq!(RelativeName::empty_slice().first(), None); assert_eq!( - RelativeDname::from_slice(b"\x03www") + RelativeName::from_slice(b"\x03www") .unwrap() .first() .unwrap() @@ -1343,7 +1338,7 @@ mod test { b"www" ); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example") + RelativeName::from_slice(b"\x03www\x07example") .unwrap() .first() .unwrap() @@ -1354,9 +1349,9 @@ mod test { #[test] fn last() { - assert_eq!(RelativeDname::empty_slice().last(), None); + assert_eq!(RelativeName::empty_slice().last(), None); assert_eq!( - RelativeDname::from_slice(b"\x03www") + RelativeName::from_slice(b"\x03www") .unwrap() .last() .unwrap() @@ -1364,7 +1359,7 @@ mod test { b"www" ); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example") + RelativeName::from_slice(b"\x03www\x07example") .unwrap() .last() .unwrap() @@ -1375,10 +1370,10 @@ mod test { #[test] fn ndots() { - assert_eq!(RelativeDname::empty_slice().ndots(), 0); - assert_eq!(RelativeDname::from_slice(b"\x03www").unwrap().ndots(), 0); + assert_eq!(RelativeName::empty_slice().ndots(), 0); + assert_eq!(RelativeName::from_slice(b"\x03www").unwrap().ndots(), 0); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example") + RelativeName::from_slice(b"\x03www\x07example") .unwrap() .ndots(), 1 @@ -1389,28 +1384,28 @@ mod test { fn starts_with() { let matrix = [ ( - RelativeDname::empty_slice(), + RelativeName::empty_slice(), [true, false, false, false, false, false], ), ( - RelativeDname::from_slice(b"\x03www").unwrap(), + RelativeName::from_slice(b"\x03www").unwrap(), [true, true, false, false, false, false], ), ( - RelativeDname::from_slice(b"\x03www\x07example").unwrap(), + RelativeName::from_slice(b"\x03www\x07example").unwrap(), [true, true, true, false, false, false], ), ( - RelativeDname::from_slice(b"\x03www\x07example\x03com") + RelativeName::from_slice(b"\x03www\x07example\x03com") .unwrap(), [true, true, true, true, false, false], ), ( - RelativeDname::from_slice(b"\x07example\x03com").unwrap(), + RelativeName::from_slice(b"\x07example\x03com").unwrap(), [true, false, false, false, true, false], ), ( - RelativeDname::from_slice(b"\x03com").unwrap(), + RelativeName::from_slice(b"\x03com").unwrap(), [true, false, false, false, false, true], ), ]; @@ -1431,28 +1426,28 @@ mod test { fn ends_with() { let matrix = [ ( - RelativeDname::empty_slice(), + RelativeName::empty_slice(), [true, false, false, false, false, false], ), ( - RelativeDname::from_slice(b"\x03www").unwrap(), + RelativeName::from_slice(b"\x03www").unwrap(), [true, true, false, false, false, false], ), ( - RelativeDname::from_slice(b"\x03www\x07example").unwrap(), + RelativeName::from_slice(b"\x03www\x07example").unwrap(), [true, false, true, false, false, false], ), ( - RelativeDname::from_slice(b"\x03www\x07example\x03com") + RelativeName::from_slice(b"\x03www\x07example\x03com") .unwrap(), [true, false, false, true, true, true], ), ( - RelativeDname::from_slice(b"\x07example\x03com").unwrap(), + RelativeName::from_slice(b"\x07example\x03com").unwrap(), [true, false, false, false, true, true], ), ( - RelativeDname::from_slice(b"\x03com").unwrap(), + RelativeName::from_slice(b"\x03com").unwrap(), [true, false, false, false, false, true], ), ]; @@ -1472,7 +1467,7 @@ mod test { #[test] fn is_label_start() { let wec = - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(); + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(); assert!(wec.is_label_start(0)); // \x03 assert!(!wec.is_label_start(1)); // w @@ -1498,7 +1493,7 @@ mod test { #[cfg(feature = "std")] fn slice() { let wec = - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(); + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(); assert_eq!(wec.slice(0..4).as_slice(), b"\x03www"); assert_eq!(wec.slice(0..12).as_slice(), b"\x03www\x07example"); assert_eq!(wec.slice(4..12).as_slice(), b"\x07example"); @@ -1517,7 +1512,7 @@ mod test { #[cfg(feature = "std")] fn range() { let wec = - RelativeDname::from_octets(b"\x03www\x07example\x03com".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03com".as_ref()) .unwrap(); assert_eq!(wec.range(0..4).as_slice(), b"\x03www"); assert_eq!(wec.range(0..12).as_slice(), b"\x03www\x07example"); @@ -1537,7 +1532,7 @@ mod test { #[cfg(feature = "std")] fn split() { let wec = - RelativeDname::from_octets(b"\x03www\x07example\x03com".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03com".as_ref()) .unwrap(); let (left, right) = wec.split(0); @@ -1566,7 +1561,7 @@ mod test { #[cfg(feature = "std")] fn truncate() { let wec = - RelativeDname::from_octets(b"\x03www\x07example\x03com".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03com".as_ref()) .unwrap(); let mut tmp = wec.clone(); @@ -1594,7 +1589,7 @@ mod test { #[test] fn split_first() { let wec = - RelativeDname::from_octets(b"\x03www\x07example\x03com".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03com".as_ref()) .unwrap(); let (label, wec) = wec.split_first().unwrap(); @@ -1614,7 +1609,7 @@ mod test { #[test] fn parent() { let wec = - RelativeDname::from_octets(b"\x03www\x07example\x03com".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03com".as_ref()) .unwrap(); let wec = wec.parent().unwrap(); @@ -1632,17 +1627,17 @@ mod test { #[test] fn strip_suffix() { let wec = - RelativeDname::from_octets(b"\x03www\x07example\x03com".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03com".as_ref()) .unwrap(); - let ec = RelativeDname::from_octets(b"\x07example\x03com".as_ref()) + let ec = RelativeName::from_octets(b"\x07example\x03com".as_ref()) .unwrap(); - let c = RelativeDname::from_octets(b"\x03com".as_ref()).unwrap(); + let c = RelativeName::from_octets(b"\x03com".as_ref()).unwrap(); let wen = - RelativeDname::from_octets(b"\x03www\x07example\x03net".as_ref()) + RelativeName::from_octets(b"\x03www\x07example\x03net".as_ref()) .unwrap(); - let en = RelativeDname::from_octets(b"\x07example\x03net".as_ref()) + let en = RelativeName::from_octets(b"\x07example\x03net".as_ref()) .unwrap(); - let n = RelativeDname::from_slice(b"\x03net".as_ref()).unwrap(); + let n = RelativeName::from_slice(b"\x03net".as_ref()).unwrap(); let mut tmp = wec.clone(); assert_eq!(tmp.strip_suffix(&wec), Ok(())); @@ -1657,7 +1652,7 @@ mod test { assert_eq!(tmp.as_slice(), b"\x03www\x07example"); let mut tmp = wec.clone(); - assert_eq!(tmp.strip_suffix(&RelativeDname::empty_ref()), Ok(())); + assert_eq!(tmp.strip_suffix(&RelativeName::empty_ref()), Ok(())); assert_eq!(tmp.as_slice(), b"\x03www\x07example\x03com"); assert!(wec.clone().strip_suffix(&wen).is_err()); @@ -1670,45 +1665,42 @@ mod test { #[test] fn eq() { assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(), - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap() + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(), + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap() ); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(), - RelativeDname::from_slice(b"\x03wWw\x07eXAMple\x03Com").unwrap() + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(), + RelativeName::from_slice(b"\x03wWw\x07eXAMple\x03Com").unwrap() ); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(), - &RelativeDname::from_octets(b"\x03www") + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(), + &RelativeName::from_octets(b"\x03www") .unwrap() .chain( - RelativeDname::from_octets(b"\x07example\x03com") - .unwrap() + RelativeName::from_octets(b"\x07example\x03com").unwrap() ) .unwrap() ); assert_eq!( - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(), - &RelativeDname::from_octets(b"\x03wWw") + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(), + &RelativeName::from_octets(b"\x03wWw") .unwrap() .chain( - RelativeDname::from_octets(b"\x07eXAMple\x03coM") - .unwrap() + RelativeName::from_octets(b"\x07eXAMple\x03coM").unwrap() ) .unwrap() ); assert_ne!( - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(), - RelativeDname::from_slice(b"\x03ww4\x07example\x03com").unwrap() + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(), + RelativeName::from_slice(b"\x03ww4\x07example\x03com").unwrap() ); assert_ne!( - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(), - &RelativeDname::from_octets(b"\x03www") + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(), + &RelativeName::from_octets(b"\x03www") .unwrap() .chain( - RelativeDname::from_octets(b"\x073xample\x03com") - .unwrap() + RelativeName::from_octets(b"\x073xample\x03com").unwrap() ) .unwrap() ); @@ -1720,16 +1712,16 @@ mod test { // The following is taken from section 6.1 of RFC 4034. let names = [ - RelativeDname::from_slice(b"\x07example").unwrap(), - RelativeDname::from_slice(b"\x01a\x07example").unwrap(), - RelativeDname::from_slice(b"\x08yljkjljk\x01a\x07example") + RelativeName::from_slice(b"\x07example").unwrap(), + RelativeName::from_slice(b"\x01a\x07example").unwrap(), + RelativeName::from_slice(b"\x08yljkjljk\x01a\x07example") .unwrap(), - RelativeDname::from_slice(b"\x01Z\x01a\x07example").unwrap(), - RelativeDname::from_slice(b"\x04zABC\x01a\x07example").unwrap(), - RelativeDname::from_slice(b"\x01z\x07example").unwrap(), - RelativeDname::from_slice(b"\x01\x01\x01z\x07example").unwrap(), - RelativeDname::from_slice(b"\x01*\x01z\x07example").unwrap(), - RelativeDname::from_slice(b"\x01\xc8\x01z\x07example").unwrap(), + RelativeName::from_slice(b"\x01Z\x01a\x07example").unwrap(), + RelativeName::from_slice(b"\x04zABC\x01a\x07example").unwrap(), + RelativeName::from_slice(b"\x01z\x07example").unwrap(), + RelativeName::from_slice(b"\x01\x01\x01z\x07example").unwrap(), + RelativeName::from_slice(b"\x01*\x01z\x07example").unwrap(), + RelativeName::from_slice(b"\x01\xc8\x01z\x07example").unwrap(), ]; for i in 0..names.len() { for j in 0..names.len() { @@ -1740,9 +1732,9 @@ mod test { } let n1 = - RelativeDname::from_slice(b"\x03www\x07example\x03com").unwrap(); + RelativeName::from_slice(b"\x03www\x07example\x03com").unwrap(); let n2 = - RelativeDname::from_slice(b"\x03wWw\x07eXAMple\x03Com").unwrap(); + RelativeName::from_slice(b"\x03wWw\x07eXAMple\x03Com").unwrap(); assert_eq!(n1.partial_cmp(n2), Some(Ordering::Equal)); assert_eq!(n1.cmp(n2), Ordering::Equal); } @@ -1755,10 +1747,10 @@ mod test { let mut s1 = DefaultHasher::new(); let mut s2 = DefaultHasher::new(); - RelativeDname::from_slice(b"\x03www\x07example\x03com") + RelativeName::from_slice(b"\x03www\x07example\x03com") .unwrap() .hash(&mut s1); - RelativeDname::from_slice(b"\x03wWw\x07eXAMple\x03Com") + RelativeName::from_slice(b"\x03wWw\x07eXAMple\x03Com") .unwrap() .hash(&mut s2); assert_eq!(s1.finish(), s2.finish()); @@ -1770,7 +1762,7 @@ mod test { use std::string::ToString; fn cmp(bytes: &[u8], fmt: &str) { - let name = RelativeDname::from_octets(bytes).unwrap(); + let name = RelativeName::from_octets(bytes).unwrap(); assert_eq!(name.to_string(), fmt); } @@ -1784,7 +1776,7 @@ mod test { fn ser_de() { use serde_test::{assert_tokens, Configure, Token}; - let name = RelativeDname::from_octets(Vec::from( + let name = RelativeName::from_octets(Vec::from( b"\x03www\x07example\x03com".as_ref(), )) .unwrap(); @@ -1792,7 +1784,7 @@ mod test { &name.clone().compact(), &[ Token::NewtypeStruct { - name: "RelativeDname", + name: "RelativeName", }, Token::ByteBuf(b"\x03www\x07example\x03com"), ], @@ -1801,7 +1793,7 @@ mod test { &name.readable(), &[ Token::NewtypeStruct { - name: "RelativeDname", + name: "RelativeName", }, Token::Str("www.example.com"), ], diff --git a/src/base/name/traits.rs b/src/base/name/traits.rs index 2d7df1276..347010e32 100644 --- a/src/base/name/traits.rs +++ b/src/base/name/traits.rs @@ -2,10 +2,10 @@ //! //! This is a private module. Its public traits are re-exported by the parent. +use super::absolute::Name; use super::chain::{Chain, LongChainError}; -use super::dname::Dname; use super::label::Label; -use super::relative::RelativeDname; +use super::relative::RelativeName; #[cfg(feature = "bytes")] use bytes::Bytes; use core::cmp; @@ -21,13 +21,13 @@ use std::borrow::Cow; /// A type that can produce an iterator over its labels. /// -/// This trait is used as a trait bound for both [`ToDname`] and -/// [`ToRelativeDname`]. It is separate since it has to be generic over the +/// This trait is used as a trait bound for both [`ToName`] and +/// [`ToRelativeName`]. It is separate since it has to be generic over the /// lifetime of the label reference but we don’t want to have this lifetime /// parameter pollute those traits. /// -/// [`ToDname`]: trait.ToDname.html -/// [`ToRelativeDname`]: trait ToRelativeDname.html +/// [`ToName`]: trait.ToName.html +/// [`ToRelativeName`]: trait ToRelativeName.html #[allow(clippy::len_without_is_empty)] pub trait ToLabelIter { /// The type of the iterator over the labels. @@ -92,7 +92,7 @@ impl<'r, N: ToLabelIter + ?Sized> ToLabelIter for &'r N { } } -//------------ ToDname ------------------------------------------------------- +//------------ ToName ------------------------------------------------------- /// A type that represents an absolute domain name. /// @@ -102,22 +102,22 @@ impl<'r, N: ToLabelIter + ?Sized> ToLabelIter for &'r N { /// label sequence via an iterator and know how to compose the wire-format /// representation into a buffer. /// -/// The most common types implementing this trait are [`Dname`], -/// [`ParsedDname`], and [`Chain`] where `R` is `ToDname` itself. +/// The most common types implementing this trait are [`Name`], +/// [`ParsedDname`], and [`Chain`] where `R` is `ToName` itself. /// /// [`Chain`]: struct.Chain.html -/// [`Dname`]: struct.Dname.html +/// [`Name`]: struct.Name.html /// [`ParsedDname`]: struct.ParsedDname.html -pub trait ToDname: ToLabelIter { +pub trait ToName: ToLabelIter { /// Converts the name into a single, uncompressed name. /// /// The default implementation provided by the trait iterates over the - /// labels of the name and adds them one by one to [`Dname`]. This will + /// labels of the name and adds them one by one to [`Name`]. This will /// work for any name but an optimized implementation can be provided for /// some types of names. - fn try_to_dname( + fn try_to_name( &self, - ) -> Result, BuilderAppendError> + ) -> Result, BuilderAppendError> where Octets: FromBuilder, ::Builder: EmptyBuilder, @@ -126,27 +126,27 @@ pub trait ToDname: ToLabelIter { Octets::Builder::with_capacity(self.compose_len().into()); self.iter_labels() .try_for_each(|label| label.compose(&mut builder))?; - Ok(unsafe { Dname::from_octets_unchecked(builder.freeze()) }) + Ok(unsafe { Name::from_octets_unchecked(builder.freeze()) }) } /// Converts the name into a single, uncompressed name. /// - /// This is the same as [`try_to_dname`][ToDname::try_to_dname] but for + /// This is the same as [`try_to_name`][ToName::try_to_name] but for /// builder types with an unrestricted buffer. - fn to_dname(&self) -> Dname + fn to_name(&self) -> Name where Octets: FromBuilder, ::Builder: OctetsBuilder, ::Builder: EmptyBuilder, { - infallible(self.try_to_dname()) + infallible(self.try_to_name()) } /// Converts the name into a single name in canonical form. - fn try_to_canonical_dname( + fn try_to_canonical_name( &self, - ) -> Result, BuilderAppendError> + ) -> Result, BuilderAppendError> where Octets: FromBuilder, ::Builder: EmptyBuilder, @@ -155,22 +155,22 @@ pub trait ToDname: ToLabelIter { Octets::Builder::with_capacity(self.compose_len().into()); self.iter_labels() .try_for_each(|label| label.compose_canonical(&mut builder))?; - Ok(unsafe { Dname::from_octets_unchecked(builder.freeze()) }) + Ok(unsafe { Name::from_octets_unchecked(builder.freeze()) }) } /// Converts the name into a single name in canonical form. /// /// This is the same as - /// [`try_to_canonical_dname`][ToDname::try_to_canonical_dname] but for + /// [`try_to_canonical_name`][ToName::try_to_canonical_name] but for /// builder types with an unrestricted buffer. - fn to_canonical_dname(&self) -> Dname + fn to_canonical_name(&self) -> Name where Octets: FromBuilder, ::Builder: OctetsBuilder, ::Builder: EmptyBuilder, { - infallible(self.try_to_canonical_dname()) + infallible(self.try_to_canonical_name()) } /// Returns an octets slice of the content if possible. @@ -213,39 +213,39 @@ pub trait ToDname: ToLabelIter { /// /// If the name is available as one single slice – i.e., /// [`as_flat_slice`] returns ‘some,’ creates the borrowed variant from - /// that slice. Otherwise assembles an owned variant via [`to_dname`]. + /// that slice. Otherwise assembles an owned variant via [`to_name`]. /// /// [`as_flat_slice`]: #method.as_flat_slice - /// [`to_dname`]: #method.to_dname + /// [`to_name`]: #method.to_name #[cfg(feature = "std")] - fn to_cow(&self) -> Dname> { + fn to_cow(&self) -> Name> { let octets = self .as_flat_slice() .map(Cow::Borrowed) .unwrap_or_else(|| Cow::Owned(self.to_vec().into_octets())); - unsafe { Dname::from_octets_unchecked(octets) } + unsafe { Name::from_octets_unchecked(octets) } } /// Returns the domain name assembled into a `Vec`. #[cfg(feature = "std")] - fn to_vec(&self) -> Dname> { - self.to_dname() + fn to_vec(&self) -> Name> { + self.to_name() } /// Returns the domain name assembled into a bytes value. #[cfg(feature = "bytes")] - fn to_bytes(&self) -> Dname { - self.to_dname() + fn to_bytes(&self) -> Name { + self.to_name() } /// Tests whether `self` and `other` are equal. /// /// This method can be used to implement `PartialEq` on types implementing - /// `ToDname` since a blanket implementation for all pairs of `ToDname` + /// `ToName` since a blanket implementation for all pairs of `ToName` /// is currently impossible. /// /// Domain names are compared ignoring ASCII case. - fn name_eq(&self, other: &N) -> bool { + fn name_eq(&self, other: &N) -> bool { if let (Some(left), Some(right)) = (self.as_flat_slice(), other.as_flat_slice()) { @@ -260,14 +260,14 @@ pub trait ToDname: ToLabelIter { /// Returns the ordering between `self` and `other`. /// /// This method can be used to implement both `PartialOrd` and `Ord` on - /// types implementing `ToDname` since a blanket implementation for all - /// pairs of `ToDname`s is currently not possible. + /// types implementing `ToName` since a blanket implementation for all + /// pairs of `ToName`s is currently not possible. /// /// Domain name order is determined according to the ‘canonical DNS /// name order’ as defined in [section 6.1 of RFC 4034][RFC4034-6.1]. /// /// [RFC4034-6.1]: https://tools.ietf.org/html/rfc4034#section-6.1 - fn name_cmp(&self, other: &N) -> cmp::Ordering { + fn name_cmp(&self, other: &N) -> cmp::Ordering { let mut self_iter = self.iter_labels(); let mut other_iter = other.iter_labels(); loop { @@ -284,7 +284,7 @@ pub trait ToDname: ToLabelIter { } /// Returns the composed name ordering. - fn composed_cmp(&self, other: &N) -> cmp::Ordering { + fn composed_cmp(&self, other: &N) -> cmp::Ordering { if let (Some(left), Some(right)) = (self.as_flat_slice(), other.as_flat_slice()) { @@ -310,7 +310,7 @@ pub trait ToDname: ToLabelIter { } /// Returns the lowercase composed ordering. - fn lowercase_composed_cmp( + fn lowercase_composed_cmp( &self, other: &N, ) -> cmp::Ordering { @@ -351,9 +351,9 @@ pub trait ToDname: ToLabelIter { } } -impl<'a, N: ToDname + ?Sized + 'a> ToDname for &'a N {} +impl<'a, N: ToName + ?Sized + 'a> ToName for &'a N {} -//------------ ToRelativeDname ----------------------------------------------- +//------------ ToRelativeName ------------------------------------------------ /// A type that represents a relative domain name. /// @@ -365,24 +365,24 @@ impl<'a, N: ToDname + ?Sized + 'a> ToDname for &'a N {} /// one character long root label, a valid absolute name can be constructed /// from the relative name. /// -/// The most important types implementing this trait are [`RelativeDname`] -/// and [`Chain`] where `R` is a `ToRelativeDname` itself. +/// The most important types implementing this trait are [`RelativeName`] +/// and [`Chain`] where `R` is a `ToRelativeName` itself. /// /// [`Chain`]: struct.Chain.html -/// [`RelativeDname`]: struct.RelativeDname.html -pub trait ToRelativeDname: ToLabelIter { +/// [`RelativeName`]: struct.RelativeName.html +pub trait ToRelativeName: ToLabelIter { /// Converts the name into a single, continous name. /// /// The canonical implementation provided by the trait iterates over the - /// labels of the name and adds them one by one to [`RelativeDname`]. + /// labels of the name and adds them one by one to [`RelativeName`]. /// This will work for any name but an optimized implementation can be /// provided for /// some types of names. /// - /// [`RelativeDname`]: struct.RelativeDname.html - fn try_to_relative_dname( + /// [`RelativeName`]: struct.RelativeName.html + fn try_to_relative_name( &self, - ) -> Result, BuilderAppendError> + ) -> Result, BuilderAppendError> where Octets: FromBuilder, ::Builder: EmptyBuilder, @@ -391,28 +391,28 @@ pub trait ToRelativeDname: ToLabelIter { Octets::Builder::with_capacity(self.compose_len().into()); self.iter_labels() .try_for_each(|label| label.compose(&mut builder))?; - Ok(unsafe { RelativeDname::from_octets_unchecked(builder.freeze()) }) + Ok(unsafe { RelativeName::from_octets_unchecked(builder.freeze()) }) } /// Converts the name into a single, continous name. /// /// This is the same as - /// [`try_to_relative_dname`][ToRelativeDname::try_to_relative_dname] + /// [`try_to_relative_name`][ToRelativeName::try_to_relative_name] /// but for builder types with an unrestricted buffer. - fn to_relative_dname(&self) -> RelativeDname + fn to_relative_name(&self) -> RelativeName where Octets: FromBuilder, ::Builder: OctetsBuilder, ::Builder: EmptyBuilder, { - infallible(self.try_to_relative_dname()) + infallible(self.try_to_relative_name()) } /// Converts the name into a single name in canonical form. - fn try_to_canonical_relative_dname( + fn try_to_canonical_relative_name( &self, - ) -> Result, BuilderAppendError> + ) -> Result, BuilderAppendError> where Octets: FromBuilder, ::Builder: EmptyBuilder, @@ -421,22 +421,22 @@ pub trait ToRelativeDname: ToLabelIter { Octets::Builder::with_capacity(self.compose_len().into()); self.iter_labels() .try_for_each(|label| label.compose_canonical(&mut builder))?; - Ok(unsafe { RelativeDname::from_octets_unchecked(builder.freeze()) }) + Ok(unsafe { RelativeName::from_octets_unchecked(builder.freeze()) }) } /// Converts the name into a single name in canonical form. /// /// This is the same as - /// [`try_to_canonical_relative_dname`][ToRelativeDname::try_to_canonical_relative_dname] + /// [`try_to_canonical_relative_name`][ToRelativeName::try_to_canonical_relative_name] /// but for builder types with an unrestricted buffer. - fn to_canonical_relative_dname(&self) -> RelativeDname + fn to_canonical_relative_name(&self) -> RelativeName where Octets: FromBuilder, ::Builder: OctetsBuilder, ::Builder: EmptyBuilder, { - infallible(self.try_to_canonical_relative_dname()) + infallible(self.try_to_canonical_relative_name()) } /// Returns a byte slice of the content if possible. @@ -475,29 +475,30 @@ pub trait ToRelativeDname: ToLabelIter { /// /// If the name is available as one single slice – i.e., /// [`as_flat_slice`] returns ‘some,’ creates the borrowed variant from - /// that slice. Otherwise assembles an owned variant via [`to_dname`]. + /// that slice. Otherwise assembles an owned variant via + /// [`to_relative_name`]. /// /// [`as_flat_slice`]: #method.as_flat_slice - /// [`to_dname`]: #method.to_dname + /// [`to_relatove_name`]: #method.to_relative_name #[cfg(feature = "std")] - fn to_cow(&self) -> RelativeDname> { + fn to_cow(&self) -> RelativeName> { let octets = self .as_flat_slice() .map(Cow::Borrowed) .unwrap_or_else(|| Cow::Owned(self.to_vec().into_octets())); - unsafe { RelativeDname::from_octets_unchecked(octets) } + unsafe { RelativeName::from_octets_unchecked(octets) } } /// Returns the domain name assembled into a `Vec`. #[cfg(feature = "std")] - fn to_vec(&self) -> RelativeDname> { - self.to_relative_dname() + fn to_vec(&self) -> RelativeName> { + self.to_relative_name() } /// Returns the domain name assembled into a bytes value. #[cfg(feature = "bytes")] - fn to_bytes(&self) -> RelativeDname { - self.to_relative_dname() + fn to_bytes(&self) -> RelativeName { + self.to_relative_name() } /// Returns whether the name is empty. @@ -517,22 +518,22 @@ pub trait ToRelativeDname: ToLabelIter { } /// Returns the absolute name by chaining it with the root label. - fn chain_root(self) -> Chain> + fn chain_root(self) -> Chain> where Self: Sized, { // Appending the root label will always work. - Chain::new(self, Dname::root()).unwrap() + Chain::new(self, Name::root()).unwrap() } /// Tests whether `self` and `other` are equal. /// /// This method can be used to implement `PartialEq` on types implementing - /// `ToDname` since a blanket implementation for all pairs of `ToDname` + /// `ToName` since a blanket implementation for all pairs of `ToName` /// is currently impossible. /// /// Domain names are compared ignoring ASCII case. - fn name_eq(&self, other: &N) -> bool { + fn name_eq(&self, other: &N) -> bool { if let (Some(left), Some(right)) = (self.as_flat_slice(), other.as_flat_slice()) { @@ -545,8 +546,8 @@ pub trait ToRelativeDname: ToLabelIter { /// Returns the ordering between `self` and `other`. /// /// This method can be used to implement both `PartialOrd` and `Ord` on - /// types implementing `ToDname` since a blanket implementation for all - /// pairs of `ToDname`s is currently not possible. + /// types implementing `ToName` since a blanket implementation for all + /// pairs of `ToName`s is currently not possible. /// /// Domain name order is determined according to the ‘canonical DNS /// name order’ as defined in [section 6.1 of RFC 4034][RFC4034-6.1]. @@ -556,7 +557,7 @@ pub trait ToRelativeDname: ToLabelIter { /// same name. /// /// [RFC4034-6.1]: https://tools.ietf.org/html/rfc4034#section-6.1 - fn name_cmp( + fn name_cmp( &self, other: &N, ) -> cmp::Ordering { @@ -576,7 +577,7 @@ pub trait ToRelativeDname: ToLabelIter { } } -impl<'a, N: ToRelativeDname + ?Sized + 'a> ToRelativeDname for &'a N {} +impl<'a, N: ToRelativeName + ?Sized + 'a> ToRelativeName for &'a N {} //------------ FlattenInto --------------------------------------------------- diff --git a/src/base/name/uncertain.rs b/src/base/name/uncertain.rs index 89024e385..9f6c48e72 100644 --- a/src/base/name/uncertain.rs +++ b/src/base/name/uncertain.rs @@ -4,11 +4,11 @@ use super::super::scan::Scanner; use super::super::wire::ParseError; -use super::builder::{DnameBuilder, FromStrError, PushError}; +use super::absolute::Name; +use super::builder::{FromStrError, NameBuilder, PushError}; use super::chain::{Chain, LongChainError}; -use super::dname::Dname; use super::label::{Label, LabelTypeError, SplitLabelError}; -use super::relative::{DnameIter, RelativeDname}; +use super::relative::{NameIter, RelativeName}; use super::traits::ToLabelIter; #[cfg(feature = "bytes")] use bytes::Bytes; @@ -21,27 +21,27 @@ use octseq::serde::{DeserializeOctets, SerializeOctets}; #[cfg(feature = "std")] use std::vec::Vec; -//------------ UncertainDname ------------------------------------------------ +//------------ UncertainName ------------------------------------------------ /// A domain name that may be absolute or relative. /// /// This type is helpful when reading a domain name from some source where it /// may end up being absolute or not. #[derive(Clone)] -pub enum UncertainDname { - Absolute(Dname), - Relative(RelativeDname), +pub enum UncertainName { + Absolute(Name), + Relative(RelativeName), } -impl UncertainDname { +impl UncertainName { /// Creates a new uncertain domain name from an absolute domain name. - pub fn absolute(name: Dname) -> Self { - UncertainDname::Absolute(name) + pub fn absolute(name: Name) -> Self { + UncertainName::Absolute(name) } /// Creates a new uncertain domain name from a relative domain name. - pub fn relative(name: RelativeDname) -> Self { - UncertainDname::Relative(name) + pub fn relative(name: RelativeName) -> Self { + UncertainName::Relative(name) } /// Creates a new uncertain domain name containing the root label only. @@ -50,7 +50,7 @@ impl UncertainDname { where Octets: From<&'static [u8]>, { - UncertainDname::Absolute(Dname::root()) + UncertainName::Absolute(Name::root()) } /// Creates a new uncertain yet empty domain name. @@ -59,7 +59,7 @@ impl UncertainDname { where Octets: From<&'static [u8]>, { - UncertainDname::Relative(RelativeDname::empty()) + UncertainName::Relative(RelativeName::empty()) } /// Creates a new domain name from its wire format representation. @@ -71,12 +71,12 @@ impl UncertainDname { Octets: AsRef<[u8]>, { if Self::is_slice_absolute(octets.as_ref())? { - Ok(UncertainDname::Absolute(unsafe { - Dname::from_octets_unchecked(octets) + Ok(UncertainName::Absolute(unsafe { + Name::from_octets_unchecked(octets) })) } else { - Ok(UncertainDname::Relative(unsafe { - RelativeDname::from_octets_unchecked(octets) + Ok(UncertainName::Relative(unsafe { + RelativeName::from_octets_unchecked(octets) })) } } @@ -85,7 +85,7 @@ impl UncertainDname { fn is_slice_absolute( mut slice: &[u8], ) -> Result { - if slice.len() > Dname::MAX_LEN { + if slice.len() > Name::MAX_LEN { return Err(UncertainDnameErrorEnum::LongName.into()); } loop { @@ -131,23 +131,23 @@ impl UncertainDname { C: IntoIterator, { let mut builder = - DnameBuilder::<::Builder>::new(); + NameBuilder::<::Builder>::new(); builder.append_chars(chars)?; if builder.in_label() || builder.is_empty() { Ok(builder.finish().into()) } else { - Ok(builder.into_dname()?.into()) + Ok(builder.into_name()?.into()) } } - pub fn scan>>( + pub fn scan>>( scanner: &mut S, ) -> Result { - scanner.scan_dname().map(UncertainDname::Absolute) + scanner.scan_name().map(UncertainName::Absolute) } } -impl UncertainDname<&'static [u8]> { +impl UncertainName<&'static [u8]> { /// Creates an empty relative name atop a slice reference. #[must_use] pub fn empty_ref() -> Self { @@ -162,7 +162,7 @@ impl UncertainDname<&'static [u8]> { } #[cfg(feature = "std")] -impl UncertainDname> { +impl UncertainName> { /// Creates an empty relative name atop a `Vec`. #[must_use] pub fn empty_vec() -> Self { @@ -177,7 +177,7 @@ impl UncertainDname> { } #[cfg(feature = "bytes")] -impl UncertainDname { +impl UncertainName { /// Creates an empty relative name atop a bytes value. pub fn empty_bytes() -> Self { Self::empty() @@ -189,12 +189,12 @@ impl UncertainDname { } } -impl UncertainDname { +impl UncertainName { /// Returns whether the name is absolute. pub fn is_absolute(&self) -> bool { match *self { - UncertainDname::Absolute(_) => true, - UncertainDname::Relative(_) => false, + UncertainName::Absolute(_) => true, + UncertainName::Relative(_) => false, } } @@ -204,17 +204,17 @@ impl UncertainDname { } /// Returns a reference to an absolute name, if this name is absolute. - pub fn as_absolute(&self) -> Option<&Dname> { + pub fn as_absolute(&self) -> Option<&Name> { match *self { - UncertainDname::Absolute(ref name) => Some(name), + UncertainName::Absolute(ref name) => Some(name), _ => None, } } /// Returns a reference to a relative name, if the name is relative. - pub fn as_relative(&self) -> Option<&RelativeDname> { + pub fn as_relative(&self) -> Option<&RelativeName> { match *self { - UncertainDname::Relative(ref name) => Some(name), + UncertainName::Relative(ref name) => Some(name), _ => None, } } @@ -222,27 +222,27 @@ impl UncertainDname { /// Converts the name into an absolute name. /// /// If the name is relative, appends the root label to it using - /// [`RelativeDname::into_absolute`]. + /// [`RelativeName::into_absolute`]. /// - /// [`RelativeDname::into_absolute`]: - /// struct.RelativeDname.html#method.into_absolute - pub fn into_absolute(self) -> Result, PushError> + /// [`RelativeName::into_absolute`]: + /// struct.RelativeName.html#method.into_absolute + pub fn into_absolute(self) -> Result, PushError> where Octets: AsRef<[u8]> + IntoBuilder, ::Builder: FreezeBuilder + AsRef<[u8]> + AsMut<[u8]>, { match self { - UncertainDname::Absolute(name) => Ok(name), - UncertainDname::Relative(name) => name.into_absolute(), + UncertainName::Absolute(name) => Ok(name), + UncertainName::Relative(name) => name.into_absolute(), } } /// Converts the name into an absolute name if it is absolute. /// /// Otherwise, returns itself as the error. - pub fn try_into_absolute(self) -> Result, Self> { - if let UncertainDname::Absolute(name) = self { + pub fn try_into_absolute(self) -> Result, Self> { + if let UncertainName::Absolute(name) = self { Ok(name) } else { Err(self) @@ -252,8 +252,8 @@ impl UncertainDname { /// Converts the name into a relative name if it is relative. /// /// Otherwise just returns itself as the error. - pub fn try_into_relative(self) -> Result, Self> { - if let UncertainDname::Relative(name) = self { + pub fn try_into_relative(self) -> Result, Self> { + if let UncertainName::Relative(name) = self { Ok(name) } else { Err(self) @@ -263,8 +263,8 @@ impl UncertainDname { /// Returns a reference to the underlying octets sequence. pub fn as_octets(&self) -> &Octets { match *self { - UncertainDname::Absolute(ref name) => name.as_octets(), - UncertainDname::Relative(ref name) => name.as_octets(), + UncertainName::Absolute(ref name) => name.as_octets(), + UncertainName::Relative(ref name) => name.as_octets(), } } @@ -274,8 +274,8 @@ impl UncertainDname { Octets: AsRef<[u8]>, { match *self { - UncertainDname::Absolute(ref name) => name.as_slice(), - UncertainDname::Relative(ref name) => name.as_slice(), + UncertainName::Absolute(ref name) => name.as_slice(), + UncertainName::Relative(ref name) => name.as_slice(), } } @@ -298,21 +298,21 @@ impl UncertainDname { //--- From -impl From> for UncertainDname { - fn from(src: Dname) -> Self { - UncertainDname::Absolute(src) +impl From> for UncertainName { + fn from(src: Name) -> Self { + UncertainName::Absolute(src) } } -impl From> for UncertainDname { - fn from(src: RelativeDname) -> Self { - UncertainDname::Relative(src) +impl From> for UncertainName { + fn from(src: RelativeName) -> Self { + UncertainName::Relative(src) } } //--- FromStr -impl str::FromStr for UncertainDname +impl str::FromStr for UncertainName where Octets: FromBuilder, ::Builder: EmptyBuilder @@ -329,34 +329,33 @@ where //--- AsRef -impl AsRef for UncertainDname { +impl AsRef for UncertainName { fn as_ref(&self) -> &Octs { match *self { - UncertainDname::Absolute(ref name) => name.as_ref(), - UncertainDname::Relative(ref name) => name.as_ref(), + UncertainName::Absolute(ref name) => name.as_ref(), + UncertainName::Relative(ref name) => name.as_ref(), } } } -impl> AsRef<[u8]> for UncertainDname { +impl> AsRef<[u8]> for UncertainName { fn as_ref(&self) -> &[u8] { match *self { - UncertainDname::Absolute(ref name) => name.as_ref(), - UncertainDname::Relative(ref name) => name.as_ref(), + UncertainName::Absolute(ref name) => name.as_ref(), + UncertainName::Relative(ref name) => name.as_ref(), } } } //--- PartialEq, and Eq -impl PartialEq> - for UncertainDname +impl PartialEq> for UncertainName where Octets: AsRef<[u8]>, Other: AsRef<[u8]>, { - fn eq(&self, other: &UncertainDname) -> bool { - use UncertainDname::*; + fn eq(&self, other: &UncertainName) -> bool { + use UncertainName::*; match (self, other) { (Absolute(l), Absolute(r)) => l.eq(r), @@ -366,11 +365,11 @@ where } } -impl> Eq for UncertainDname {} +impl> Eq for UncertainName {} //--- Hash -impl> hash::Hash for UncertainDname { +impl> hash::Hash for UncertainName { fn hash(&self, state: &mut H) { for item in self.iter_labels() { item.hash(state) @@ -380,29 +379,29 @@ impl> hash::Hash for UncertainDname { //--- ToLabelIter -impl> ToLabelIter for UncertainDname { - type LabelIter<'a> = DnameIter<'a> where Octs: 'a; +impl> ToLabelIter for UncertainName { + type LabelIter<'a> = NameIter<'a> where Octs: 'a; fn iter_labels(&self) -> Self::LabelIter<'_> { match *self { - UncertainDname::Absolute(ref name) => name.iter_labels(), - UncertainDname::Relative(ref name) => name.iter_labels(), + UncertainName::Absolute(ref name) => name.iter_labels(), + UncertainName::Relative(ref name) => name.iter_labels(), } } fn compose_len(&self) -> u16 { match *self { - UncertainDname::Absolute(ref name) => name.compose_len(), - UncertainDname::Relative(ref name) => name.compose_len(), + UncertainName::Absolute(ref name) => name.compose_len(), + UncertainName::Relative(ref name) => name.compose_len(), } } } //--- IntoIterator -impl<'a, Octets: AsRef<[u8]>> IntoIterator for &'a UncertainDname { +impl<'a, Octets: AsRef<[u8]>> IntoIterator for &'a UncertainName { type Item = &'a Label; - type IntoIter = DnameIter<'a>; + type IntoIter = NameIter<'a>; fn into_iter(self) -> Self::IntoIter { self.iter_labels() @@ -411,25 +410,25 @@ impl<'a, Octets: AsRef<[u8]>> IntoIterator for &'a UncertainDname { //--- Display and Debug -impl> fmt::Display for UncertainDname { +impl> fmt::Display for UncertainName { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - UncertainDname::Absolute(ref name) => { + UncertainName::Absolute(ref name) => { write!(f, "{}.", name) } - UncertainDname::Relative(ref name) => name.fmt(f), + UncertainName::Relative(ref name) => name.fmt(f), } } } -impl> fmt::Debug for UncertainDname { +impl> fmt::Debug for UncertainName { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { - UncertainDname::Absolute(ref name) => { - write!(f, "UncertainDname::Absolute({})", name) + UncertainName::Absolute(ref name) => { + write!(f, "UncertainName::Absolute({})", name) } - UncertainDname::Relative(ref name) => { - write!(f, "UncertainDname::Relative({})", name) + UncertainName::Relative(ref name) => { + write!(f, "UncertainName::Relative({})", name) } } } @@ -438,7 +437,7 @@ impl> fmt::Debug for UncertainDname { //--- Serialize and Deserialize #[cfg(feature = "serde")] -impl serde::Serialize for UncertainDname +impl serde::Serialize for UncertainName where Octets: AsRef<[u8]> + SerializeOctets, { @@ -448,12 +447,12 @@ where ) -> Result { if serializer.is_human_readable() { serializer.serialize_newtype_struct( - "UncertainDname", + "UncertainName", &format_args!("{}", self), ) } else { serializer.serialize_newtype_struct( - "UncertainDname", + "UncertainName", &self.as_octets().as_serialized_octets(), ) } @@ -461,7 +460,7 @@ where } #[cfg(feature = "serde")] -impl<'de, Octets> serde::Deserialize<'de> for UncertainDname +impl<'de, Octets> serde::Deserialize<'de> for UncertainName where Octets: FromBuilder + DeserializeOctets<'de>, ::Builder: EmptyBuilder @@ -484,7 +483,7 @@ where + AsRef<[u8]> + AsMut<[u8]>, { - type Value = UncertainDname; + type Value = UncertainName; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("a domain name") @@ -496,7 +495,7 @@ where ) -> Result { use core::str::FromStr; - UncertainDname::from_str(v).map_err(E::custom) + UncertainName::from_str(v).map_err(E::custom) } fn visit_borrowed_bytes( @@ -504,7 +503,7 @@ where value: &'de [u8], ) -> Result { self.0.visit_borrowed_bytes(value).and_then(|octets| { - UncertainDname::from_octets(octets).map_err(E::custom) + UncertainName::from_octets(octets).map_err(E::custom) }) } @@ -514,7 +513,7 @@ where value: std::vec::Vec, ) -> Result { self.0.visit_byte_buf(value).and_then(|octets| { - UncertainDname::from_octets(octets).map_err(E::custom) + UncertainName::from_octets(octets).map_err(E::custom) }) } } @@ -529,7 +528,7 @@ where + AsRef<[u8]> + AsMut<[u8]>, { - type Value = UncertainDname; + type Value = UncertainName; fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { f.write_str("a domain name") @@ -552,7 +551,7 @@ where } deserializer.deserialize_newtype_struct( - "UncertainDname", + "UncertainName", NewtypeVisitor(PhantomData), ) } @@ -650,7 +649,7 @@ mod test { #[test] fn from_str() { - type U = UncertainDname>; + type U = UncertainName>; fn name(s: &str) -> U { U::from_str(s).unwrap() @@ -730,14 +729,14 @@ mod test { use serde_test::{assert_tokens, Configure, Token}; let abs_name = - UncertainDname::>::from_str("www.example.com.").unwrap(); + UncertainName::>::from_str("www.example.com.").unwrap(); assert!(abs_name.is_absolute()); assert_tokens( &abs_name.clone().compact(), &[ Token::NewtypeStruct { - name: "UncertainDname", + name: "UncertainName", }, Token::ByteBuf(b"\x03www\x07example\x03com\0"), ], @@ -746,21 +745,21 @@ mod test { &abs_name.readable(), &[ Token::NewtypeStruct { - name: "UncertainDname", + name: "UncertainName", }, Token::Str("www.example.com."), ], ); let rel_name = - UncertainDname::>::from_str("www.example.com").unwrap(); + UncertainName::>::from_str("www.example.com").unwrap(); assert!(rel_name.is_relative()); assert_tokens( &rel_name.clone().compact(), &[ Token::NewtypeStruct { - name: "UncertainDname", + name: "UncertainName", }, Token::ByteBuf(b"\x03www\x07example\x03com"), ], @@ -769,7 +768,7 @@ mod test { &rel_name.readable(), &[ Token::NewtypeStruct { - name: "UncertainDname", + name: "UncertainName", }, Token::Str("www.example.com"), ], diff --git a/src/base/net/nostd.rs b/src/base/net/nostd.rs index 59b381f15..71d9a6ae2 100644 --- a/src/base/net/nostd.rs +++ b/src/base/net/nostd.rs @@ -261,6 +261,12 @@ impl From<[u8; 16]> for IpAddr { } } +impl From<[u16; 8]> for IpAddr { + fn from(src: [u16; 8]) -> Self { + IpAddr::V6(src.into()) + } +} + impl From for IpAddr { fn from(addr: Ipv4Addr) -> Self { IpAddr::V4(addr) diff --git a/src/base/opt/chain.rs b/src/base/opt/chain.rs index dec136775..eab6e2d91 100644 --- a/src/base/opt/chain.rs +++ b/src/base/opt/chain.rs @@ -9,7 +9,7 @@ use core::fmt; use super::super::iana::OptionCode; use super::super::message_builder::OptBuilder; -use super::super::name::{Dname, ToDname}; +use super::super::name::{Name, ToName}; use super::super::wire::{Composer, ParseError}; use super::{Opt, OptData, ComposeOptData, ParseOptData}; use octseq::builder::OctetsBuilder; @@ -74,12 +74,12 @@ impl Chain { } } -impl Chain> { +impl Chain> { /// Parses CHAIN option data from its wire format. pub fn parse<'a, Src: Octets = Octs> + ?Sized>( parser: &mut Parser<'a, Src> ) -> Result { - Dname::parse(parser).map(Self::new) + Name::parse(parser).map(Self::new) } } @@ -98,29 +98,29 @@ where Name: OctetsFrom { impl PartialEq> for Chain where - Name: ToDname, - OtherName: ToDname + Name: ToName, + OtherName: ToName { fn eq(&self, other: &Chain) -> bool { self.start().name_eq(other.start()) } } -impl Eq for Chain { } +impl Eq for Chain { } //--- PartialOrd and Ord impl PartialOrd> for Chain where - Name: ToDname, - OtherName: ToDname + Name: ToName, + OtherName: ToName { fn partial_cmp(&self, other: &Chain) -> Option { Some(self.start().name_cmp(other.start())) } } -impl Ord for Chain { +impl Ord for Chain { fn cmp(&self, other: &Self) -> Ordering { self.start().name_cmp(other.start()) } @@ -142,7 +142,7 @@ impl OptData for Chain { } } -impl<'a, Octs> ParseOptData<'a, Octs> for Chain>> +impl<'a, Octs> ParseOptData<'a, Octs> for Chain>> where Octs: Octets { fn parse_option( code: OptionCode, @@ -157,7 +157,7 @@ where Octs: Octets { } } -impl ComposeOptData for Chain { +impl ComposeOptData for Chain { fn compose_len(&self) -> u16 { self.start.compose_len() } @@ -192,7 +192,7 @@ impl Opt { /// /// The CHAIN option allows a client to request that all records that /// are necessary for DNSSEC validation are included in the response. - pub fn chain(&self) -> Option>>> { + pub fn chain(&self) -> Option>>> { self.first() } } @@ -205,7 +205,7 @@ impl<'a, Target: Composer> OptBuilder<'a, Target> { /// The `start` name is the longest suffix of the queried owner name /// for which the client already has all necessary records. pub fn chain( - &mut self, start: impl ToDname + &mut self, start: impl ToName ) -> Result<(), Target::AppendError> { self.push(&Chain::new(start)) } @@ -225,7 +225,7 @@ mod test { #[allow(clippy::redundant_closure)] // lifetimes ... fn chain_compose_parse() { test_option_compose_parse( - &Chain::new(Dname::>::from_str("example.com").unwrap()), + &Chain::new(Name::>::from_str("example.com").unwrap()), |parser| Chain::parse(parser) ); } diff --git a/src/base/opt/macros.rs b/src/base/opt/macros.rs index 93d49063e..d17474f3e 100644 --- a/src/base/opt/macros.rs +++ b/src/base/opt/macros.rs @@ -83,7 +83,7 @@ macro_rules! opt_types { } impl<'a, Octs: Octets> ParseOptData<'a, Octs> - for AllOptData, Dname>> { + for AllOptData, Name>> { fn parse_option( code: OptionCode, parser: &mut Parser<'a, Octs>, @@ -106,7 +106,7 @@ macro_rules! opt_types { } impl ComposeOptData for AllOptData - where Octs: AsRef<[u8]>, Name: ToDname { + where Octs: AsRef<[u8]>, Name: ToName { fn compose_len(&self) -> u16 { match *self { $( $( diff --git a/src/base/opt/mod.rs b/src/base/opt/mod.rs index e340c68fd..8e157eecf 100644 --- a/src/base/opt/mod.rs +++ b/src/base/opt/mod.rs @@ -42,7 +42,7 @@ opt_types! { use super::cmp::CanonicalOrd; use super::header::Header; use super::iana::{Class, OptRcode, OptionCode, Rtype}; -use super::name::{Dname, ToDname}; +use super::name::{Name, ToName}; use super::rdata::{ComposeRecordData, ParseRecordData, RecordData}; use super::record::{Record, Ttl}; use super::wire::{Compose, Composer, FormError, ParseError}; @@ -510,7 +510,7 @@ pub struct OptRecord { impl OptRecord { /// Converts a regular record into an OPT record - pub fn from_record(record: Record>) -> Self { + pub fn from_record(record: Record>) -> Self { OptRecord { udp_payload_size: record.class().to_int(), ext_rcode: (record.ttl().as_secs() >> 24) as u8, @@ -521,12 +521,12 @@ impl OptRecord { } /// Converts the OPT record into a regular record. - pub fn as_record(&self) -> Record<&'static Dname<[u8]>, Opt<&[u8]>> + pub fn as_record(&self) -> Record<&'static Name<[u8]>, Opt<&[u8]>> where Octs: AsRef<[u8]>, { Record::new( - Dname::root_slice(), + Name::root_slice(), Class::from_int(self.udp_payload_size), Ttl::from_secs( u32::from(self.ext_rcode) << 24 @@ -634,7 +634,7 @@ impl Default for OptRecord { //--- From -impl From>> for OptRecord { +impl From>> for OptRecord { fn from(record: Record>) -> Self { Self::from_record(record) } diff --git a/src/base/question.rs b/src/base/question.rs index a3aedc0f5..91fb1472e 100644 --- a/src/base/question.rs +++ b/src/base/question.rs @@ -6,7 +6,7 @@ use super::cmp::CanonicalOrd; use super::iana::{Class, Rtype}; -use super::name::{ParsedDname, ToDname}; +use super::name::{ParsedName, ToName}; use super::wire::{Composer, ParseError}; use core::cmp::Ordering; use core::{fmt, hash}; @@ -22,10 +22,10 @@ use octseq::parse::Parser; /// represents such a question. /// /// Questions are generic over the domain name type. When read from an -/// actual message, a [`ParsedDname`] has to be used because the name part +/// actual message, a [`ParsedName`] has to be used because the name part /// may be compressed. /// -/// [`ParsedDname`]: ../name/struct.ParsedDname.html +/// [`ParsedName`]: ../name/struct.ParsedName.html /// [`MessageBuilder`]: ../message_builder/struct.MessageBuilder.html #[derive(Clone, Copy)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -69,7 +69,7 @@ impl Question { /// # Field Access /// -impl Question { +impl Question { /// Returns a reference to the domain nmae in the question, pub fn qname(&self) -> &N { &self.qname @@ -88,24 +88,24 @@ impl Question { /// # Parsing and Composing /// -impl Question> { +impl Question> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { Ok(Question::new( - ParsedDname::parse(parser)?, + ParsedName::parse(parser)?, Rtype::parse(parser)?, Class::parse(parser)?, )) } } -impl Question { +impl Question { pub fn compose( &self, target: &mut Target, ) -> Result<(), Target::AppendError> { - target.append_compressed_dname(&self.qname)?; + target.append_compressed_name(&self.qname)?; self.qtype.compose(target)?; self.qclass.compose(target) } @@ -113,13 +113,13 @@ impl Question { //--- From -impl From<(N, Rtype, Class)> for Question { +impl From<(N, Rtype, Class)> for Question { fn from((name, rtype, class): (N, Rtype, Class)) -> Self { Question::new(name, rtype, class) } } -impl From<(N, Rtype)> for Question { +impl From<(N, Rtype)> for Question { fn from((name, rtype): (N, Rtype)) -> Self { Question::new(name, rtype, Class::IN) } @@ -148,8 +148,8 @@ where impl PartialEq> for Question where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Question) -> bool { self.qname.name_eq(&other.qname) @@ -158,14 +158,14 @@ where } } -impl Eq for Question {} +impl Eq for Question {} //--- PartialOrd, CanonicalOrd, and Ord impl PartialOrd> for Question where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Question) -> Option { match self.qname.name_cmp(&other.qname) { @@ -182,8 +182,8 @@ where impl CanonicalOrd> for Question where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn canonical_cmp(&self, other: &Question) -> Ordering { match self.qname.lowercase_composed_cmp(&other.qname) { @@ -198,7 +198,7 @@ where } } -impl Ord for Question { +impl Ord for Question { fn cmp(&self, other: &Self) -> Ordering { match self.qname.name_cmp(&other.qname) { Ordering::Equal => {} @@ -274,7 +274,7 @@ impl<'a, Q: ComposeQuestion> ComposeQuestion for &'a Q { } } -impl ComposeQuestion for Question { +impl ComposeQuestion for Question { fn compose_question( &self, target: &mut Target, @@ -283,7 +283,7 @@ impl ComposeQuestion for Question { } } -impl ComposeQuestion for (Name, Rtype, Class) { +impl ComposeQuestion for (Name, Rtype, Class) { fn compose_question( &self, target: &mut Target, @@ -292,7 +292,7 @@ impl ComposeQuestion for (Name, Rtype, Class) { } } -impl ComposeQuestion for (Name, Rtype) { +impl ComposeQuestion for (Name, Rtype) { fn compose_question( &self, target: &mut Target, diff --git a/src/base/record.rs b/src/base/record.rs index 3289d2577..8ea99467d 100644 --- a/src/base/record.rs +++ b/src/base/record.rs @@ -17,7 +17,7 @@ use super::cmp::CanonicalOrd; use super::iana::{Class, Rtype}; -use super::name::{FlattenInto, ParsedDname, ToDname}; +use super::name::{FlattenInto, ParsedName, ToName}; use super::rdata::{ ComposeRecordData, ParseAnyRecordData, ParseRecordData, RecordData, }; @@ -188,7 +188,7 @@ impl Record { /// Parsing and Composing /// -impl Record, Data> { +impl Record, Data> { pub fn parse<'a, Src: Octets = Octs> + 'a>( parser: &mut Parser<'a, Src>, ) -> Result, ParseError> @@ -200,12 +200,12 @@ impl Record, Data> { } } -impl Record { +impl Record { pub fn compose( &self, target: &mut Target, ) -> Result<(), Target::AppendError> { - target.append_compressed_dname(&self.owner)?; + target.append_compressed_name(&self.owner)?; self.data.rtype().compose(target)?; self.class.compose(target)?; self.ttl.compose(target)?; @@ -356,8 +356,8 @@ where impl CanonicalOrd> for Record where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, D: RecordData + CanonicalOrd
, DD: RecordData, { @@ -465,7 +465,7 @@ impl<'a, T: ComposeRecord> ComposeRecord for &'a T { impl ComposeRecord for Record where - Name: ToDname, + Name: ToName, Data: ComposeRecordData, { fn compose_record( @@ -478,7 +478,7 @@ where impl ComposeRecord for (Name, Class, u32, Data) where - Name: ToDname, + Name: ToName, Data: ComposeRecordData, { fn compose_record( @@ -492,7 +492,7 @@ where impl ComposeRecord for (Name, Class, Ttl, Data) where - Name: ToDname, + Name: ToName, Data: ComposeRecordData, { fn compose_record( @@ -505,7 +505,7 @@ where impl ComposeRecord for (Name, u32, Data) where - Name: ToDname, + Name: ToName, Data: ComposeRecordData, { fn compose_record( @@ -519,7 +519,7 @@ where impl ComposeRecord for (Name, Ttl, Data) where - Name: ToDname, + Name: ToName, Data: ComposeRecordData, { fn compose_record( @@ -569,8 +569,8 @@ impl RecordHeader { } } -impl<'a, Octs: Octets + ?Sized> RecordHeader> { - fn deref_owner(&self) -> RecordHeader>> { +impl<'a, Octs: Octets + ?Sized> RecordHeader> { + fn deref_owner(&self) -> RecordHeader>> { RecordHeader { owner: self.owner.deref_octets(), rtype: self.rtype, @@ -615,7 +615,7 @@ impl RecordHeader { /// # Parsing and Composing /// -impl RecordHeader> { +impl RecordHeader> { pub fn parse<'a, Src: Octets = Octs>>( parser: &mut Parser<'a, Src>, ) -> Result { @@ -623,12 +623,12 @@ impl RecordHeader> { } } -impl<'a, Octs: AsRef<[u8]> + ?Sized> RecordHeader> { +impl<'a, Octs: AsRef<[u8]> + ?Sized> RecordHeader> { pub fn parse_ref( parser: &mut Parser<'a, Octs>, ) -> Result { Ok(RecordHeader::new( - ParsedDname::parse_ref(parser)?, + ParsedName::parse_ref(parser)?, Rtype::parse(parser)?, Class::parse(parser)?, Ttl::parse(parser)?, @@ -662,7 +662,7 @@ impl RecordHeader<()> { fn parse_rdlen( parser: &mut Parser, ) -> Result { - ParsedDname::skip(parser)?; + ParsedName::skip(parser)?; parser.advance( (Rtype::COMPOSE_LEN + Class::COMPOSE_LEN + u32::COMPOSE_LEN) .into(), @@ -671,7 +671,7 @@ impl RecordHeader<()> { } } -impl RecordHeader> { +impl RecordHeader> { /// Parses the remainder of the record if the record data type supports it. /// /// The method assumes that the parser is currently positioned right @@ -682,7 +682,7 @@ impl RecordHeader> { pub fn parse_into_record<'a, Src, Data>( self, parser: &mut Parser<'a, Src>, - ) -> Result, Data>>, ParseError> + ) -> Result, Data>>, ParseError> where Src: AsRef<[u8]> + ?Sized, Data: ParseRecordData<'a, Src>, @@ -705,7 +705,7 @@ impl RecordHeader> { pub fn parse_into_any_record<'a, Src, Data>( self, parser: &mut Parser<'a, Src>, - ) -> Result, Data>, ParseError> + ) -> Result, Data>, ParseError> where Src: AsRef<[u8]> + ?Sized, Data: ParseAnyRecordData<'a, Src>, @@ -726,12 +726,12 @@ impl RecordHeader> { } } -impl RecordHeader { +impl RecordHeader { pub fn compose( &self, buf: &mut Target, ) -> Result<(), Target::AppendError> { - buf.append_compressed_dname(&self.owner)?; + buf.append_compressed_name(&self.owner)?; self.rtype.compose(buf)?; self.class.compose(buf)?; self.ttl.compose(buf)?; @@ -754,8 +754,8 @@ impl RecordHeader { impl PartialEq> for RecordHeader where - Name: ToDname, - NName: ToDname, + Name: ToName, + NName: ToName, { fn eq(&self, other: &RecordHeader) -> bool { self.owner.name_eq(&other.owner) @@ -766,7 +766,7 @@ where } } -impl Eq for RecordHeader {} +impl Eq for RecordHeader {} //--- PartialOrd and Ord // @@ -774,8 +774,8 @@ impl Eq for RecordHeader {} impl PartialOrd> for RecordHeader where - Name: ToDname, - NName: ToDname, + Name: ToName, + NName: ToName, { fn partial_cmp(&self, other: &RecordHeader) -> Option { match self.owner.name_cmp(&other.owner) { @@ -798,7 +798,7 @@ where } } -impl Ord for RecordHeader { +impl Ord for RecordHeader { fn cmp(&self, other: &Self) -> Ordering { match self.owner.name_cmp(&other.owner) { Ordering::Equal => {} @@ -866,7 +866,7 @@ impl fmt::Debug for RecordHeader { #[derive(Clone)] pub struct ParsedRecord<'a, Octs: Octets + ?Sized> { /// The record’s header. - header: RecordHeader>, + header: RecordHeader>, /// A parser positioned at the beginning of the record’s data. data: Parser<'a, Octs>, @@ -879,7 +879,7 @@ impl<'a, Octs: Octets + ?Sized> ParsedRecord<'a, Octs> { /// first byte of the record data. #[must_use] pub fn new( - header: RecordHeader>, + header: RecordHeader>, data: Parser<'a, Octs>, ) -> Self { ParsedRecord { header, data } @@ -887,7 +887,7 @@ impl<'a, Octs: Octets + ?Sized> ParsedRecord<'a, Octs> { /// Returns a reference to the owner of the record. #[must_use] - pub fn owner(&self) -> ParsedDname<&'a Octs> { + pub fn owner(&self) -> ParsedName<&'a Octs> { *self.header.owner() } @@ -929,7 +929,7 @@ impl<'a, Octs: Octets + ?Sized> ParsedRecord<'a, Octs> { #[allow(clippy::type_complexity)] pub fn to_record( &self, - ) -> Result>, Data>>, ParseError> + ) -> Result>, Data>>, ParseError> where Data: ParseRecordData<'a, Octs>, { @@ -945,7 +945,7 @@ impl<'a, Octs: Octets + ?Sized> ParsedRecord<'a, Octs> { /// this trait for parsing. pub fn to_any_record( &self, - ) -> Result>, Data>, ParseError> + ) -> Result>, Data>, ParseError> where Data: ParseAnyRecordData<'a, Octs>, { @@ -966,7 +966,7 @@ impl<'a, Octs: Octets + ?Sized> ParsedRecord<'a, Octs> { #[allow(clippy::type_complexity)] pub fn into_record( mut self, - ) -> Result>, Data>>, ParseError> + ) -> Result>, Data>>, ParseError> where Data: ParseRecordData<'a, Octs>, { @@ -980,7 +980,7 @@ impl<'a, Octs: Octets + ?Sized> ParsedRecord<'a, Octs> { /// this trait for parsing. #[allow(clippy::type_complexity)] pub fn into_any_record( mut self, - ) -> Result>, Data>, ParseError> + ) -> Result>, Data>, ParseError> where Data: ParseAnyRecordData<'a, Octs>, { @@ -1586,13 +1586,13 @@ mod test { fn ds_octets_into() { use super::*; use crate::base::iana::{Class, DigestAlg, SecAlg}; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::rdata::Ds; use bytes::Bytes; use octseq::octets::OctetsInto; - let ds: Record, Ds<&[u8]>> = Record::new( - Dname::from_octets(b"\x01a\x07example\0".as_ref()).unwrap(), + let ds: Record, Ds<&[u8]>> = Record::new( + Name::from_octets(b"\x01a\x07example\0".as_ref()).unwrap(), Class::IN, Ttl::from_secs(86400), Ds::new( @@ -1603,7 +1603,7 @@ mod test { ) .unwrap(), ); - let ds_bytes: Record, Ds> = + let ds_bytes: Record, Ds> = ds.clone().octets_into(); assert_eq!(ds.owner(), ds_bytes.owner()); assert_eq!(ds.data().digest(), ds_bytes.data().digest()); diff --git a/src/base/scan.rs b/src/base/scan.rs index 62605c586..e3a7aa59f 100644 --- a/src/base/scan.rs +++ b/src/base/scan.rs @@ -22,7 +22,7 @@ #![allow(unused_imports)] // XXX use crate::base::charstr::{CharStr, CharStrBuilder}; -use crate::base::name::{Dname, ToDname}; +use crate::base::name::{Name, ToName}; use crate::base::wire::{Compose, Composer}; use core::convert::{TryFrom, TryInto}; use core::iter::Peekable; @@ -153,7 +153,7 @@ pub trait Scanner { + FreezeBuilder; /// The type of a domain name returned by the scanner. - type Dname: ToDname; + type Name: ToName; /// The error type of the scanner. type Error: ScannerError; @@ -220,7 +220,7 @@ pub trait Scanner { F: FnOnce(&str) -> Result; /// Scans a token into a domain name. - fn scan_dname(&mut self) -> Result; + fn scan_name(&mut self) -> Result; /// Scans a token into a character string. /// @@ -851,7 +851,7 @@ where { type Octets = Octets; type OctetsBuilder = ::Builder; - type Dname = Dname; + type Name = Name; type Error = StrError; fn has_space(&self) -> bool { @@ -957,12 +957,12 @@ where } } - fn scan_dname(&mut self) -> Result { + fn scan_name(&mut self) -> Result { let token = match self.iter.next() { Some(token) => token, None => return Err(StrError::end_of_entry()), }; - Dname::from_symbols(Symbols::new(token.as_ref().chars())) + Name::from_symbols(Symbols::new(token.as_ref().chars())) .map_err(|_| StrError::custom("invalid domain name")) } diff --git a/src/base/wire.rs b/src/base/wire.rs index b10acf0b2..a97ef3386 100644 --- a/src/base/wire.rs +++ b/src/base/wire.rs @@ -1,6 +1,6 @@ //! Creating and consuming data in wire format. -use super::name::ToDname; +use super::name::ToName; use super::net::{Ipv4Addr, Ipv6Addr}; use core::fmt; use octseq::builder::{OctetsBuilder, Truncate}; @@ -23,7 +23,7 @@ pub trait Composer: /// /// The trait provides a default implementation which simply appends the /// name uncompressed. - fn append_compressed_dname( + fn append_compressed_name( &mut self, name: &N, ) -> Result<(), Self::AppendError> { @@ -50,11 +50,11 @@ impl> Composer for smallvec::SmallVec {} impl Composer for heapless::Vec {} impl Composer for &mut T { - fn append_compressed_dname( + fn append_compressed_name( &mut self, name: &N, ) -> Result<(), Self::AppendError> { - Composer::append_compressed_dname(*self, name) + Composer::append_compressed_name(*self, name) } fn can_compress(&self) -> bool { diff --git a/src/net/client/cache.rs b/src/net/client/cache.rs index 8cdcaabda..1f26a27ec 100644 --- a/src/net/client/cache.rs +++ b/src/net/client/cache.rs @@ -30,10 +30,9 @@ #![warn(clippy::missing_docs_in_private_items)] use crate::base::iana::{Class, Opcode, OptRcode, Rtype}; -use crate::base::name::ToDname; +use crate::base::name::ToName; use crate::base::{ - Dname, Header, Message, MessageBuilder, ParsedDname, StaticCompressor, - Ttl, + Header, Message, MessageBuilder, Name, ParsedName, StaticCompressor, Ttl, }; use crate::dep::octseq::Octets; // use crate::net::client::clock::{Clock, Elapsed, SystemClock}; @@ -773,7 +772,7 @@ enum RequestState { /// Note that the AD and DO flags are combined into a single enum. struct Key { /// DNS name in the request. - qname: Dname, + qname: Name, /// The request class. Always IN at the moment. qclass: Class, @@ -803,10 +802,10 @@ impl Key { rd: bool, ) -> Key where - TDN: ToDname, + TDN: ToName, { Self { - qname: qname.to_canonical_dname(), + qname: qname.to_canonical_name(), qclass, qtype, addo: AdDo::new(ad, dnssec_ok), @@ -920,7 +919,7 @@ impl Value orig_qname: TDN, ) -> Option, Error>> where - TDN: ToDname + Clone, + TDN: ToName + Clone, // C: Clock + Send + Sync, { let elapsed = self.created_at.elapsed(); @@ -1012,7 +1011,7 @@ fn decrement_ttl( amount: u32, ) -> Result, Error> where - TDN: ToDname + Clone, + TDN: ToName + Clone, { let msg = match response { Err(err) => return Err(err.clone()), @@ -1041,7 +1040,7 @@ where let mut target = target.answer(); for rr in &mut source { let mut rr = rr? - .into_record::>>()? + .into_record::>>()? .expect("record expected"); rr.set_ttl(rr.ttl() - amount); target.push(rr).expect("push failed"); @@ -1052,7 +1051,7 @@ where let mut target = target.authority(); for rr in &mut source { let mut rr = rr? - .into_record::>>()? + .into_record::>>()? .expect("record expected"); rr.set_ttl(rr.ttl() - amount); target.push(rr).expect("push failed"); @@ -1063,7 +1062,7 @@ where for rr in source { let rr = rr?; let mut rr = rr - .into_record::>>()? + .into_record::>>()? .expect("record expected"); if rr.rtype() != Rtype::OPT { rr.set_ttl(rr.ttl() - amount); @@ -1107,7 +1106,7 @@ fn remove_dnssec( let mut target = target.answer(); for rr in &mut source { let rr = rr? - .into_record::>>()? + .into_record::>>()? .expect("record expected"); if is_dnssec(rr.rtype()) { continue; @@ -1120,7 +1119,7 @@ fn remove_dnssec( let mut target = target.authority(); for rr in &mut source { let rr = rr? - .into_record::>>()? + .into_record::>>()? .expect("record expected"); if is_dnssec(rr.rtype()) { continue; @@ -1133,7 +1132,7 @@ fn remove_dnssec( for rr in source { let rr = rr?; let rr = rr - .into_record::>>()? + .into_record::>>()? .expect("record expected"); if is_dnssec(rr.rtype()) { continue; diff --git a/src/net/client/mod.rs b/src/net/client/mod.rs index 5125a0919..7dadebac9 100644 --- a/src/net/client/mod.rs +++ b/src/net/client/mod.rs @@ -46,13 +46,13 @@ //! //! For example: //! ```rust -//! # use domain::base::{Dname, MessageBuilder, Rtype}; +//! # use domain::base::{Name, MessageBuilder, Rtype}; //! # use domain::net::client::request::RequestMessage; //! let mut msg = MessageBuilder::new_vec(); //! msg.header_mut().set_rd(true); //! let mut msg = msg.question(); //! msg.push( -//! (Dname::vec_from_str("example.com").unwrap(), Rtype::AAAA) +//! (Name::vec_from_str("example.com").unwrap(), Rtype::AAAA) //! ).unwrap(); //! let req = RequestMessage::new(msg); //! ``` diff --git a/src/net/client/protocol.rs b/src/net/client/protocol.rs index 4c4b07b3f..bd0a34ff0 100644 --- a/src/net/client/protocol.rs +++ b/src/net/client/protocol.rs @@ -6,13 +6,9 @@ use pin_project_lite::pin_project; use std::boxed::Box; use std::io; use std::net::SocketAddr; -use std::sync::Arc; use std::task::{Context, Poll}; use tokio::io::ReadBuf; use tokio::net::{TcpStream, UdpSocket}; -use tokio_rustls::client::TlsStream; -use tokio_rustls::rustls::{ClientConfig, ServerName}; -use tokio_rustls::TlsConnector; /// How many times do we try a new random port if we get ‘address in use.’ const RETRY_RANDOM_PORT: usize = 10; @@ -71,25 +67,30 @@ impl AsyncConnect for TcpConnect { //------------ TlsConnect ----------------------------------------------------- /// Create new TLS connections +#[cfg(feature = "tokio-rustls")] #[derive(Clone, Debug)] pub struct TlsConnect { /// Configuration for setting up a TLS connection. - client_config: Arc, + client_config: std::sync::Arc, /// Server name for certificate verification. - server_name: ServerName, + server_name: tokio_rustls::rustls::pki_types::ServerName<'static>, /// Remote address to connect to. addr: SocketAddr, } +#[cfg(feature = "tokio-rustls")] impl TlsConnect { /// Function to create a new TLS connection stream - pub fn new( - client_config: impl Into>, - server_name: ServerName, + pub fn new( + client_config: Conf, + server_name: tokio_rustls::rustls::pki_types::ServerName<'static>, addr: SocketAddr, - ) -> Self { + ) -> Self + where + Conf: Into>, + { Self { client_config: client_config.into(), server_name, @@ -98,8 +99,9 @@ impl TlsConnect { } } +#[cfg(feature = "tokio-rustls")] impl AsyncConnect for TlsConnect { - type Connection = TlsStream; + type Connection = tokio_rustls::client::TlsStream; type Fut = Pin< Box< dyn Future> @@ -109,7 +111,8 @@ impl AsyncConnect for TlsConnect { >; fn connect(&self) -> Self::Fut { - let tls_connection = TlsConnector::from(self.client_config.clone()); + let tls_connection = + tokio_rustls::TlsConnector::from(self.client_config.clone()); let server_name = self.server_name.clone(); let addr = self.addr; Box::pin(async move { diff --git a/src/net/client/request.rs b/src/net/client/request.rs index bdad09731..3a3fd1ab5 100644 --- a/src/net/client/request.rs +++ b/src/net/client/request.rs @@ -10,7 +10,7 @@ use crate::base::message_builder::{ }; use crate::base::opt::{ComposeOptData, LongOptData, OptRecord}; use crate::base::wire::{Composer, ParseError}; -use crate::base::{Header, Message, ParsedDname, Rtype}; +use crate::base::{Header, Message, ParsedName, Rtype}; use crate::rdata::AllRecordData; use bytes::Bytes; use octseq::Octets; @@ -147,7 +147,7 @@ impl + Debug + Octets> RequestMessage { let mut target = target.answer(); for rr in &mut source { let rr = rr? - .into_record::>>()? + .into_record::>>()? .expect("record expected"); target.push(rr)?; } @@ -157,7 +157,7 @@ impl + Debug + Octets> RequestMessage { let mut target = target.authority(); for rr in &mut source { let rr = rr? - .into_record::>>()? + .into_record::>>()? .expect("record expected"); target.push(rr)?; } @@ -169,7 +169,7 @@ impl + Debug + Octets> RequestMessage { let rr = rr?; if rr.rtype() != Rtype::OPT { let rr = rr - .into_record::>>()? + .into_record::>>()? .expect("record expected"); target.push(rr)?; } diff --git a/src/net/server/message.rs b/src/net/server/message.rs index e84e5da1e..dd1d673ed 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -10,7 +10,7 @@ use crate::base::Message; //------------ UdpTransportContext ------------------------------------------- /// Request context for a UDP transport. -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Default)] pub struct UdpTransportContext { /// Optional maximum response size hint. max_response_size_hint: Arc>>, @@ -127,6 +127,22 @@ impl TransportSpecificContext { } } +//--- impl From + +impl From for TransportSpecificContext { + fn from(ctx: UdpTransportContext) -> Self { + Self::Udp(ctx) + } +} + +//--- impl From + +impl From for TransportSpecificContext { + fn from(ctx: NonUdpTransportContext) -> Self { + Self::NonUdp(ctx) + } +} + //------------ Request ------------------------------------------------------- /// A DNS message with additional properties describing its context. diff --git a/src/net/server/middleware/cookies.rs b/src/net/server/middleware/cookies.rs index 4754fabfc..3cc320a5f 100644 --- a/src/net/server/middleware/cookies.rs +++ b/src/net/server/middleware/cookies.rs @@ -9,18 +9,17 @@ use std::vec::Vec; use futures::stream::{once, Once}; use octseq::Octets; use rand::RngCore; -use tracing::{debug, enabled, trace, warn, Level}; +use tracing::{debug, trace, warn}; -use crate::base::iana::{OptRcode, OptionCode, Rcode}; +use crate::base::iana::{OptRcode, Rcode}; use crate::base::message_builder::AdditionalBuilder; use crate::base::opt; -use crate::base::opt::Cookie; use crate::base::wire::{Composer, ParseError}; use crate::base::{Serial, StreamTarget}; use crate::net::server::message::Request; use crate::net::server::middleware::stream::MiddlewareStream; use crate::net::server::service::{CallResult, Service, ServiceResult}; -use crate::net::server::util::{add_edns_options, to_pcap_text}; +use crate::net::server::util::add_edns_options; use crate::net::server::util::{mk_builder_for_target, start_reply}; use super::stream::PostprocessingStream; @@ -161,7 +160,7 @@ where // Note: if rcode is non-extended this will also correctly handle // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |_, opt| { + if let Err(err) = add_edns_options(&mut additional, |opt| { opt.cookie(response_cookie)?; opt.set_rcode(rcode); Ok(()) @@ -216,37 +215,13 @@ where self.response_with_cookie(request, Rcode::NOERROR.into()) } - /// Check the cookie contained in the request to make sure that it is - /// complete, and if so return the cookie to the caller. - #[must_use] - fn ensure_cookie_is_complete( - request: &Request, - server_secret: &[u8; 16], - ) -> Option { - if let Some(Ok(cookie)) = Self::cookie(request) { - let cookie = if cookie.server().is_some() { - cookie - } else { - cookie.create_response( - Serial::now(), - request.client_addr().ip(), - server_secret, - ) - }; - - Some(cookie) - } else { - None - } - } - fn preprocess( &self, request: &Request, ) -> ControlFlow>> { match Self::cookie(request) { None => { - trace!("Request does not include DNS cookies"); + trace!("Request does not contain a DNS cookie"); // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 // No OPT RR or No COOKIE Option: @@ -424,9 +399,9 @@ where } fn postprocess( - request: &Request, - response: &mut AdditionalBuilder>, - server_secret: [u8; 16], + _request: &Request, + _response: &mut AdditionalBuilder>, + _server_secret: [u8; 16], ) where RequestOctets: Octets, { @@ -454,35 +429,6 @@ where // A Client Cookie and a Valid Server Cookie // Any server cookie will already have been validated during // pre-processing, we don't need to check it again here. - - if let Some(filled_cookie) = - Self::ensure_cookie_is_complete(request, &server_secret) - { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 - // "The server SHALL process the request and include a COOKIE - // option in the response by (a) copying the complete COOKIE - // option from the request or (b) generating a new COOKIE option - // containing both the Client Cookie copied from the request and - // a valid Server Cookie it has generated." - if let Err(err) = add_edns_options( - response, - |existing_option_codes, builder| { - if !existing_option_codes.contains(&OptionCode::COOKIE) { - builder.push(&filled_cookie) - } else { - Ok(()) - } - }, - ) { - warn!("Cannot add RFC 7873 DNS Cookie option to response: {err}"); - } - } - - if enabled!(Level::TRACE) { - let bytes = response.as_slice(); - let pcap_text = to_pcap_text(bytes, bytes.len()); - trace!(pcap_text, "post-processing complete"); - } } fn map_stream_item( @@ -549,3 +495,87 @@ where } } } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use std::vec::Vec; + use tokio::time::Instant; + + use crate::base::opt::cookie::ClientCookie; + use crate::base::opt::Cookie; + use crate::base::{Message, MessageBuilder, Name, Rtype}; + use crate::net::server::message::{Request, UdpTransportContext}; + use crate::net::server::middleware::cookies::CookiesMiddlewareSvc; + use crate::net::server::service::{CallResult, Service, ServiceResult}; + use crate::net::server::util::service_fn; + use futures::prelude::stream::StreamExt; + + #[tokio::test] + async fn dont_add_cookie_twice() { + // Build a dummy DNS query containing a client cookie. + let query = MessageBuilder::new_vec(); + let mut query = query.question(); + query.push((Name::::root(), Rtype::A)).unwrap(); + let mut additional = query.additional(); + let client_cookie = ClientCookie::new_random(); + let cookie = Cookie::new(client_cookie, None); + additional.opt(|builder| builder.cookie(cookie)).unwrap(); + let message = additional.into_message(); + + // Package the query into a context aware request to make it look + // as if it came from a UDP server. + let ctx = UdpTransportContext::default(); + let client_addr = "127.0.0.1:12345".parse().unwrap(); + let request = + Request::new(client_addr, Instant::now(), message, ctx.into()); + + fn my_service( + _req: Request>, + _meta: (), + ) -> ServiceResult> { + // For each request create a single response: + todo!() + } + + // And pass the query through the middleware processor + let my_svc = service_fn(my_service, ()); + let server_secret: [u8; 16] = + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; + let processor_svc = CookiesMiddlewareSvc::new(my_svc, server_secret); + + let mut stream = processor_svc.call(request).await; + let call_result: CallResult> = + stream.next().await.unwrap().unwrap(); + let (response, _feedback) = call_result.into_inner(); + + // Expect the response to contain a single cookie option containing + // both a client cookie and a server cookie. + let response = response.unwrap().finish(); + let response_bytes = response.as_dgram_slice().to_vec(); + let response = Message::from_octets(response_bytes).unwrap(); + + let Some(opt_record) = response.opt() else { + panic!("Missing OPT record") + }; + + let mut cookie_iter = opt_record.opt().iter::(); + let Some(Ok(cookie)) = cookie_iter.next() else { + panic!("Invalid or missing cookie") + }; + + assert!( + cookie.check_server_hash( + client_addr.ip(), + &server_secret, + |_| true + ), + "The cookie is incomplete or invalid" + ); + + assert!( + cookie_iter.next().is_none(), + "There should only be one COOKIE option" + ); + } +} diff --git a/src/net/server/middleware/edns.rs b/src/net/server/middleware/edns.rs index 2c980a2b0..38473ce41 100644 --- a/src/net/server/middleware/edns.rs +++ b/src/net/server/middleware/edns.rs @@ -7,7 +7,7 @@ use futures::stream::{once, Once}; use octseq::Octets; use tracing::{debug, enabled, error, trace, warn, Level}; -use crate::base::iana::{OptRcode, OptionCode}; +use crate::base::iana::OptRcode; use crate::base::message_builder::AdditionalBuilder; use crate::base::opt::keepalive::IdleTimeout; use crate::base::opt::{Opt, OptRecord, TcpKeepalive}; @@ -275,20 +275,18 @@ where // timeout is known: "Signal the timeout value // using the edns-tcp-keepalive EDNS(0) option // [RFC7828]". - if let Err(err) = add_edns_options( - response, - |existing_option_codes, builder| { - if !existing_option_codes.contains( - &OptionCode::TCP_KEEPALIVE, - ) { + if let Err(err) = + // TODO: Don't add the option if it + // already exists? + add_edns_options( + response, + |builder| { builder.push(&TcpKeepalive::new( Some(timeout), )) - } else { - Ok(()) - } - }, - ) { + }, + ) + { warn!("Cannot add RFC 7828 edns-tcp-keepalive option to response: {err}"); } } @@ -370,7 +368,7 @@ mod tests { use futures::stream::StreamExt; use tokio::time::Instant; - use crate::base::{Dname, Message, MessageBuilder, Rtype}; + use crate::base::{Message, MessageBuilder, Name, Rtype}; use crate::net::server::message::{ Request, TransportSpecificContext, UdpTransportContext, }; @@ -440,7 +438,7 @@ mod tests { // With a dummy question. let mut query = query.question(); - query.push((Dname::::root(), Rtype::A)).unwrap(); + query.push((Name::::root(), Rtype::A)).unwrap(); // And if requested, a requestor's UDP payload size: let message: Message<_> = if let Some(v) = client_value { @@ -463,7 +461,7 @@ mod tests { "127.0.0.1:12345".parse().unwrap(), Instant::now(), message, - TransportSpecificContext::Udp(ctx), + ctx.into(), ); fn my_service( diff --git a/src/net/server/middleware/mandatory.rs b/src/net/server/middleware/mandatory.rs index 912531e21..a205f1f48 100644 --- a/src/net/server/middleware/mandatory.rs +++ b/src/net/server/middleware/mandatory.rs @@ -359,14 +359,11 @@ mod tests { use bytes::Bytes; use futures::StreamExt; - use octseq::OctetsBuilder; use tokio::time::Instant; - use crate::base::iana::{OptionCode, Rcode}; - use crate::base::{Dname, MessageBuilder, Rtype}; - use crate::net::server::message::{ - Request, TransportSpecificContext, UdpTransportContext, - }; + use crate::base::iana::Rcode; + use crate::base::{MessageBuilder, Name, Rtype}; + use crate::net::server::message::{Request, UdpTransportContext}; use crate::net::server::service::{CallResult, Service, ServiceResult}; use crate::net::server::util::{mk_builder_for_target, service_fn}; @@ -403,20 +400,10 @@ mod tests { // Build a dummy DNS query. let query = MessageBuilder::new_vec(); let mut query = query.question(); - query.push((Dname::::root(), Rtype::A)).unwrap(); - let extra_bytes = vec![0; (MIN_ALLOWED as usize) * 2]; + query.push((Name::::root(), Rtype::A)).unwrap(); let mut additional = query.additional(); additional - .opt(|builder| { - builder.push_raw_option( - OptionCode::PADDING, - extra_bytes.len() as u16, - |target| { - target.append_slice(&extra_bytes).unwrap(); - Ok(()) - }, - ) - }) + .opt(|builder| builder.padding(MIN_ALLOWED * 2)) .unwrap(); let old_size = additional.as_slice().len(); let message = additional.into_message(); @@ -431,7 +418,7 @@ mod tests { "127.0.0.1:12345".parse().unwrap(), Instant::now(), message, - TransportSpecificContext::Udp(ctx), + ctx.into(), ); fn my_service( diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 46712372f..0b4c2ce71 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -69,7 +69,7 @@ pub type ServiceResult = Result, ServiceError>; /// /// use domain::base::iana::{Class, Rcode}; /// use domain::base::message_builder::AdditionalBuilder; -/// use domain::base::{Dname, Message, MessageBuilder, StreamTarget}; +/// use domain::base::{Name, Message, MessageBuilder, StreamTarget}; /// use domain::net::server::message::Request; /// use domain::net::server::service::{ /// CallResult, Service, ServiceError, Transaction @@ -83,7 +83,7 @@ pub type ServiceResult = Result, ServiceError>; /// ) -> Result>>, ServiceError> { /// let mut answer = builder.start_answer(msg.message(), Rcode::NOERROR)?; /// answer.push(( -/// Dname::root_ref(), +/// Name::root_ref(), /// Class::IN, /// 86400, /// A::from_octets(192, 0, 2, 1), @@ -117,7 +117,7 @@ pub type ServiceResult = Result, ServiceError>; /// use core::future::ready; /// use core::future::Future; /// -/// use domain::base::{Dname, Message}; +/// use domain::base::{Name, Message}; /// use domain::base::iana::{Class, Rcode}; /// use domain::base::name::ToLabelIter; /// use domain::base::wire::Composer; @@ -159,7 +159,7 @@ pub type ServiceResult = Result, ServiceError>; /// .start_answer(msg.message(), Rcode::NOERROR) /// .unwrap(); /// answer -/// .push((Dname::root_ref(), Class::IN, 86400, a_rec)) +/// .push((Name::root_ref(), Class::IN, 86400, a_rec)) /// .unwrap(); /// out_answer = Some(answer); /// } diff --git a/src/net/server/tests.rs b/src/net/server/tests.rs index 9b181f418..885deb98e 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests.rs @@ -15,8 +15,8 @@ use tokio::time::sleep; use tokio::time::Instant; use tracing::trace; -use crate::base::Dname; use crate::base::MessageBuilder; +use crate::base::Name; use crate::base::Rtype; use crate::base::StaticCompressor; use crate::base::StreamTarget; @@ -42,6 +42,8 @@ struct MockStream { /// The rate at which messages should be made available to the server. new_message_every: Duration, + + pending_responses: usize, } impl MockStream { @@ -49,10 +51,12 @@ impl MockStream { messages_to_read: VecDeque>, new_message_every: Duration, ) -> Self { + let pending_responses = messages_to_read.len(); Self { last_ready: Mutex::new(Option::None), messages_to_read: Mutex::new(messages_to_read), new_message_every, + pending_responses, } } } @@ -81,11 +85,13 @@ impl AsyncRead for MockStream { last_ready.replace(Instant::now()); return Poll::Ready(Ok(())); } else { - // End of stream - /*return Poll::Ready(Err(io::Error::new( - io::ErrorKind::ConnectionAborted, - "mock connection disconnect", - )));*/ + // Disconnect once we've sent all of the requests AND received all of the responses. + if self.pending_responses == 0 { + return Poll::Ready(Err(io::Error::new( + io::ErrorKind::ConnectionAborted, + "mock connection disconnect", + ))); + } } } _ => { @@ -110,10 +116,14 @@ impl AsyncRead for MockStream { impl AsyncWrite for MockStream { fn poll_write( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, _cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { + // Assume a single write is an entire response. + if self.pending_responses > 0 { + self.pending_responses -= 1; + } Poll::Ready(Ok(buf.len())) } @@ -333,11 +343,8 @@ fn mk_query() -> StreamTarget> { msg.header_mut().set_random_id(); let mut msg = msg.question(); - msg.push(( - Dname::>::from_str("example.com.").unwrap(), - Rtype::A, - )) - .unwrap(); + msg.push((Name::>::from_str("example.com.").unwrap(), Rtype::A)) + .unwrap(); let mut msg = msg.additional(); msg.opt(|opt| { diff --git a/src/net/server/util.rs b/src/net/server/util.rs index 034d70861..17fa2b022 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -5,17 +5,15 @@ use std::string::{String, ToString}; use futures::stream::Once; use octseq::{Octets, OctetsBuilder}; -use smallvec::SmallVec; use tracing::warn; -use crate::base::iana::{OptRcode, OptionCode, Rcode}; +use crate::base::iana::{OptRcode, Rcode}; use crate::base::message_builder::{ AdditionalBuilder, OptBuilder, PushError, QuestionBuilder, }; -use crate::base::opt::UnknownOptData; use crate::base::wire::Composer; use crate::base::Message; -use crate::base::{MessageBuilder, ParsedDname, Rtype, StreamTarget}; +use crate::base::{MessageBuilder, ParsedName, Rtype, StreamTarget}; use crate::rdata::AllRecordData; use super::message::Request; @@ -219,7 +217,7 @@ where // Note: if rcode is non-extended this will also correctly handle // setting the rcode in the main message header. - if let Err(err) = add_edns_options(&mut additional, |_, opt| { + if let Err(err) = add_edns_options(&mut additional, |opt| { opt.set_rcode(rcode); Ok(()) }) { @@ -250,7 +248,6 @@ pub fn add_edns_options( ) -> Result<(), PushError> where F: FnOnce( - &[OptionCode], &mut OptBuilder>, ) -> Result< (), @@ -287,7 +284,7 @@ where for rr in current_additional.flatten() { if rr.rtype() != Rtype::OPT { if let Ok(Some(rr)) = rr - .into_record::>>() + .into_record::>>() { response.push(rr)?; } @@ -299,25 +296,8 @@ where // the options within the existing OPT record plus the new options // that we want to add. let res = response.opt(|builder| { - let mut existing_option_codes = - SmallVec::<[OptionCode; 4]>::new(); - // Copy the header fields - builder.set_version(current_opt.version()); - builder.set_dnssec_ok(current_opt.dnssec_ok()); - builder - .set_rcode(current_opt.rcode(copied_response.header())); - builder.set_udp_payload_size(current_opt.udp_payload_size()); - - // Copy the options - for opt in - current_opt.opt().iter::>().flatten() - { - existing_option_codes.push(opt.code()); - builder.push(&opt)?; - } - - // Invoking the user supplied callback - op(&existing_option_codes, builder) + builder.clone_from(¤t_opt)?; + op(builder) }); return res; @@ -325,7 +305,7 @@ where } // No existing OPT record in the additional section so build a new one. - response.opt(|builder| op(&[], builder)) + response.opt(|builder| op(builder)) } /// Removes any OPT records present in the response. @@ -360,7 +340,7 @@ where for rr in current_additional.flatten() { if rr.rtype() != Rtype::OPT { if let Ok(Some(rr)) = rr - .into_record::>>() + .into_record::>>() { response.push(rr)?; } @@ -372,3 +352,177 @@ where Ok(()) } + +#[cfg(test)] +mod tests { + use bytes::Bytes; + use tokio::time::Instant; + + use crate::base::{Message, MessageBuilder, Name, Rtype, StreamTarget}; + use crate::net::server::message::{Request, UdpTransportContext}; + + use super::start_reply; + use crate::base::iana::{OptRcode, Rcode}; + use crate::base::message_builder::AdditionalBuilder; + use crate::base::opt::UnknownOptData; + use crate::base::wire::Composer; + use crate::net::server::util::{ + add_edns_options, remove_edns_opt_record, + }; + use std::vec::Vec; + + #[test] + fn test_add_edns_option() { + // Given a dummy DNS query. + let query = MessageBuilder::new_vec(); + let mut query = query.question(); + query.push((Name::::root(), Rtype::A)).unwrap(); + let msg = query.into_message(); + + // Package it into a received request. + let client_ip = "127.0.0.1:12345".parse().unwrap(); + let sent_at = Instant::now(); + let ctx = UdpTransportContext::default(); + let request = Request::new(client_ip, sent_at, msg, ctx.into()); + + // Create a dummy DNS reply which does not yet have an OPT record. + let reply = start_reply::<_, Vec>(request.message()); + assert_eq!(reply.counts().arcount(), 0); + assert_eq!(reply.header().rcode(), Rcode::NOERROR); + + // Add an OPT record to the reply. + let mut reply = reply.additional(); + reply + .opt(|builder| { + builder.set_rcode(OptRcode::BADCOOKIE); + builder.set_udp_payload_size(123); + Ok(()) + }) + .unwrap(); + assert_eq!(reply.counts().arcount(), 1); + + // When an OPT record exists the RCODE of the DNS message is extended + // from 4-bits to 12-bits, combining the original 4-bit RCODE in the + // DNS message header with an additional 8-bits in the OPT record + // header. This causes the main DNS header RCODE value to seem wrong + // if inspected in isolation. We set the RCODE to BADCOOKIE but that + // has value 23 which exceeds the 4-bit range maximum value and so is + // encoded as a full 12-bit RCODE. 23 in binary is 0001_0111 which as + // you can see causes the lower 4-bits to have value 0111 which is 7. + let expected_rcode = Rcode::checked_from_int(0b0111).unwrap(); + assert_eq!(reply.header().rcode(), expected_rcode); + + // Note: We can't test the upper 8-bits of the extended RCODE as there + // is no way to access the OPT record header via a message builder. We + // can however serialize the message and deserialize it again and + // check it via the Message interface. + let response = assert_opt( + reply.clone(), + expected_rcode, + Some(OptRcode::BADCOOKIE), + ); + + // And that it has no EDNS options. + let opt = response.opt().unwrap(); + let options = opt.opt(); + assert_eq!(options.len(), 0); + + // Now add an EDNS option to the OPT record. + add_edns_options(&mut reply, |builder| builder.padding(123)).unwrap(); + + // And verify that the OPT record still exists as expected. + let response = assert_opt( + reply.clone(), + expected_rcode, + Some(OptRcode::BADCOOKIE), + ); + + // And that it has a single EDNS option. + let opt = response.opt().unwrap(); + let options = opt.opt(); + assert_eq!(options.iter::>().count(), 1); + + // Now add another EDNS option to the OPT record (duplicates are allowed + // by RFC 6891). + add_edns_options(&mut reply, |builder| builder.padding(123)).unwrap(); + + // And verify that the OPT record still exists as expected. + let response = assert_opt( + reply.clone(), + expected_rcode, + Some(OptRcode::BADCOOKIE), + ); + + // And that it has a single EDNS option. + let opt = response.opt().unwrap(); + let options = opt.opt(); + assert_eq!(options.iter::>().count(), 2); + } + + #[test] + fn test_remove_edns_opt_record() { + // Given a dummy DNS query. + let query = MessageBuilder::new_vec(); + let mut query = query.question(); + query.push((Name::::root(), Rtype::A)).unwrap(); + let msg = query.into_message(); + + // Package it into a received request. + let client_ip = "127.0.0.1:12345".parse().unwrap(); + let sent_at = Instant::now(); + let ctx = UdpTransportContext::default(); + let request = Request::new(client_ip, sent_at, msg, ctx.into()); + + // Create a dummy DNS reply which does not yet have an OPT record. + let reply = start_reply::<_, Vec>(request.message()); + assert_eq!(reply.counts().arcount(), 0); + + // Add an OPT record to the reply. + let mut reply = reply.additional(); + reply.opt(|builder| builder.padding(32)).unwrap(); + assert_eq!(reply.counts().arcount(), 1); + + // Note: We can't test that the OPT record exists or inspect its properties + // when using a MessageBuilder, but we can if we serialize it and deserialize + // it again as a Message. + assert_opt(reply.clone(), Rcode::NOERROR, Some(OptRcode::NOERROR)); + + // Now remove the OPT record from the saved reply. + remove_edns_opt_record(&mut reply).unwrap(); + + // And verify that the OPT record no longer exists when serialized and + // deserialized again. + assert_opt(reply.clone(), Rcode::NOERROR, None); + } + + //------------ Helper functions ------------------------------------------ + + fn assert_opt( + reply: AdditionalBuilder>, + expected_rcode: Rcode, + expected_opt_rcode: Option, + ) -> Message> { + // Serialize the reply to wire format so that we can test that the OPT + // record was really added to a finally constructed DNS message and + // has the expected RCODE and OPT extended RCODE values. + let response = reply.finish(); + let response_bytes = response.as_dgram_slice().to_vec(); + let response = Message::from_octets(response_bytes).unwrap(); + + assert_eq!(response.header().rcode(), expected_rcode); + match expected_opt_rcode { + Some(opt_rcode) => { + assert_eq!(response.header_counts().arcount(), 1); + assert!(response.opt().is_some()); + assert_eq!(response.opt_rcode(), opt_rcode); + } + + None => { + assert_eq!(response.header_counts().arcount(), 0); + assert!(response.opt().is_none()); + } + } + + response + } +} diff --git a/src/rdata/dname.rs b/src/rdata/dname.rs index 94315b96e..bb420576f 100644 --- a/src/rdata/dname.rs +++ b/src/rdata/dname.rs @@ -1,5 +1,5 @@ use crate::base::cmp::CanonicalOrd; -use crate::base::name::{ParsedDname, ToDname}; +use crate::base::name::{ParsedName, ToName}; use crate::base::wire::ParseError; use core::cmp::Ordering; use core::str::FromStr; @@ -9,7 +9,7 @@ use octseq::parse::Parser; //------------ Dname -------------------------------------------------------- -dname_type_canonical! { +name_type_canonical! { /// DNAME record data. /// /// The DNAME record provides redirection for a subtree of the domain @@ -25,7 +25,7 @@ dname_type_canonical! { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname as Name; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; diff --git a/src/rdata/dnssec.rs b/src/rdata/dnssec.rs index cfdb932cc..e90747c36 100644 --- a/src/rdata/dnssec.rs +++ b/src/rdata/dnssec.rs @@ -6,7 +6,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::{DigestAlg, Rtype, SecAlg}; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ ComposeRecordData, LongRecordData, ParseRecordData, RecordData, }; @@ -472,7 +472,7 @@ impl ProtoRrsig { signature: Octs, ) -> Result, LongRecordData> where - Name: ToDname, + Name: ToName, { Rrsig::new( self.type_covered, @@ -488,7 +488,7 @@ impl ProtoRrsig { } } -impl ProtoRrsig { +impl ProtoRrsig { pub fn compose( &self, target: &mut Target, @@ -856,7 +856,7 @@ impl Rrsig { ) -> Result where Octs: AsRef<[u8]>, - Name: ToDname, + Name: ToName, { LongRecordData::check_len( usize::from( @@ -1002,12 +1002,12 @@ impl Rrsig { }) } - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result where Octs: AsRef<[u8]>, - Name: ToDname, + Name: ToName, { Self::new( Rtype::scan(scanner)?, @@ -1017,14 +1017,14 @@ impl Rrsig { Timestamp::scan(scanner)?, Timestamp::scan(scanner)?, u16::scan(scanner)?, - scanner.scan_dname()?, + scanner.scan_name()?, scanner.convert_entry(base64::SymbolConverter::new())?, ) .map_err(|err| S::Error::custom(err.as_str())) } } -impl Rrsig> { +impl Rrsig> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { @@ -1035,7 +1035,7 @@ impl Rrsig> { let expiration = Timestamp::parse(parser)?; let inception = Timestamp::parse(parser)?; let key_tag = u16::parse(parser)?; - let signer_name = ParsedDname::parse(parser)?; + let signer_name = ParsedName::parse(parser)?; let len = parser.remaining(); let signature = parser.parse_octets(len)?; Ok(unsafe { @@ -1101,8 +1101,8 @@ where impl PartialEq> for Rrsig where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, O: AsRef<[u8]>, OO: AsRef<[u8]>, { @@ -1122,7 +1122,7 @@ where impl Eq for Rrsig where Octs: AsRef<[u8]>, - Name: ToDname, + Name: ToName, { } @@ -1130,8 +1130,8 @@ where impl PartialOrd> for Rrsig where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, O: AsRef<[u8]>, OO: AsRef<[u8]>, { @@ -1176,8 +1176,8 @@ where impl CanonicalOrd> for Rrsig where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, O: AsRef<[u8]>, OO: AsRef<[u8]>, { @@ -1218,7 +1218,7 @@ where } } -impl, N: ToDname> Ord for Rrsig { +impl, N: ToName> Ord for Rrsig { fn cmp(&self, other: &Self) -> Ordering { self.canonical_cmp(other) } @@ -1249,7 +1249,7 @@ impl RecordData for Rrsig { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Rrsig, ParsedDname>> + for Rrsig, ParsedName>> { fn parse_rdata( rtype: Rtype, @@ -1266,7 +1266,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> impl ComposeRecordData for Rrsig where Octs: AsRef<[u8]>, - Name: ToDname, + Name: ToName, { fn rdlen(&self, _compress: bool) -> Option { Some( @@ -1305,7 +1305,7 @@ where } } -impl, Name: ToDname> Rrsig { +impl, Name: ToName> Rrsig { fn compose_head( &self, target: &mut Target, @@ -1439,22 +1439,22 @@ impl Nsec { )) } - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { Ok(Self::new( - scanner.scan_dname()?, + scanner.scan_name()?, RtypeBitmap::scan(scanner)?, )) } } -impl> Nsec> { +impl> Nsec> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { Ok(Nsec::new( - ParsedDname::parse(parser)?, + ParsedName::parse(parser)?, RtypeBitmap::parse(parser)?, )) } @@ -1499,15 +1499,15 @@ impl PartialEq> for Nsec where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Nsec) -> bool { self.next_name.name_eq(&other.next_name) && self.types == other.types } } -impl, N: ToDname> Eq for Nsec {} +impl, N: ToName> Eq for Nsec {} //--- PartialOrd, Ord, and CanonicalOrd @@ -1515,8 +1515,8 @@ impl PartialOrd> for Nsec where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Nsec) -> Option { match self.next_name.name_cmp(&other.next_name) { @@ -1531,8 +1531,8 @@ impl CanonicalOrd> for Nsec where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn canonical_cmp(&self, other: &Nsec) -> Ordering { // RFC 6840 says that Nsec::next_name is not converted to lower case. @@ -1547,7 +1547,7 @@ where impl Ord for Nsec where O: AsRef<[u8]>, - N: ToDname, + N: ToName, { fn cmp(&self, other: &Self) -> Ordering { match self.next_name.name_cmp(&other.next_name) { @@ -1576,7 +1576,7 @@ impl RecordData for Nsec { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Nsec, ParsedDname>> + for Nsec, ParsedName>> { fn parse_rdata( rtype: Rtype, @@ -1593,7 +1593,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> impl ComposeRecordData for Nsec where Octs: AsRef<[u8]>, - Name: ToDname, + Name: ToName, { fn rdlen(&self, _compress: bool) -> Option { Some( @@ -2659,7 +2659,7 @@ impl std::error::Error for IllegalSignatureTime {} mod test { use super::*; use crate::base::iana::Rtype; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; @@ -2690,7 +2690,7 @@ mod test { Timestamp::from(13), Timestamp::from(14), 15, - Dname::>::from_str("example.com.").unwrap(), + Name::>::from_str("example.com.").unwrap(), b"key", ) .unwrap(); @@ -2722,7 +2722,7 @@ mod test { rtype.add(Rtype::A).unwrap(); rtype.add(Rtype::SRV).unwrap(); let rdata = Nsec::new( - Dname::>::from_str("example.com.").unwrap(), + Name::>::from_str("example.com.").unwrap(), rtype.finalize(), ); test_rdlen(&rdata); diff --git a/src/rdata/macros.rs b/src/rdata/macros.rs index d7c0a6233..ac0b49697 100644 --- a/src/rdata/macros.rs +++ b/src/rdata/macros.rs @@ -33,7 +33,7 @@ macro_rules! rdata_types { use core::{fmt, hash}; use crate::base::cmp::CanonicalOrd; use crate::base::iana::Rtype; - use crate::base::name::{FlattenInto, ParsedDname, ToDname}; + use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::opt::Opt; use crate::base::rdata::{ ComposeRecordData, ParseAnyRecordData, ParseRecordData, @@ -83,7 +83,7 @@ macro_rules! rdata_types { Unknown(UnknownRecordData), } - impl, Name: ToDname> ZoneRecordData { + impl, Name: ToName> ZoneRecordData { /// Scans a value of the given rtype. /// /// If the record data is given via the notation for unknown @@ -94,7 +94,7 @@ macro_rules! rdata_types { scanner: &mut S ) -> Result where - S: $crate::base::scan::Scanner + S: $crate::base::scan::Scanner { if scanner.scan_opt_unknown_marker()? { UnknownRecordData::scan_without_marker( @@ -229,7 +229,7 @@ macro_rules! rdata_types { for ZoneRecordData where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, NN: ToDname, + N: ToName, NN: ToName, { fn eq(&self, other: &ZoneRecordData) -> bool { match (self, other) { @@ -254,7 +254,7 @@ macro_rules! rdata_types { } impl Eq for ZoneRecordData - where O: AsRef<[u8]>, N: ToDname { } + where O: AsRef<[u8]>, N: ToName { } //--- PartialOrd, Ord, and CanonicalOrd @@ -262,7 +262,7 @@ macro_rules! rdata_types { impl Ord for ZoneRecordData where O: AsRef<[u8]>, - N: ToDname, + N: ToName, { fn cmp(&self, other: &Self) -> core::cmp::Ordering { match (self, other) { @@ -290,7 +290,7 @@ macro_rules! rdata_types { for ZoneRecordData where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, NN: ToDname, + N: ToName, NN: ToName, { fn partial_cmp( &self, @@ -322,8 +322,8 @@ macro_rules! rdata_types { for ZoneRecordData where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: CanonicalOrd + ToDname, - NN: ToDname, + N: CanonicalOrd + ToName, + NN: ToName, { fn canonical_cmp( &self, @@ -380,7 +380,7 @@ macro_rules! rdata_types { impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for ZoneRecordData, ParsedDname>> { + for ZoneRecordData, ParsedName>> { fn parse_rdata( rtype: Rtype, parser: &mut Parser<'a, Octs>, @@ -403,7 +403,7 @@ macro_rules! rdata_types { } impl ComposeRecordData for ZoneRecordData - where Octs: AsRef<[u8]>, Name: ToDname { + where Octs: AsRef<[u8]>, Name: ToName { fn rdlen(&self, compress: bool) -> Option { match *self { $( $( $( @@ -694,7 +694,7 @@ macro_rules! rdata_types { for AllRecordData where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, NN: ToDname + N: ToName, NN: ToName { fn eq(&self, other: &AllRecordData) -> bool { match (self, other) { @@ -720,14 +720,14 @@ macro_rules! rdata_types { } impl Eq for AllRecordData - where O: AsRef<[u8]>, N: ToDname { } + where O: AsRef<[u8]>, N: ToName { } //--- PartialOrd, Ord, and CanonicalOrd impl Ord for AllRecordData where O: AsRef<[u8]>, - N: ToDname, + N: ToName, { fn cmp( &self, @@ -773,7 +773,7 @@ macro_rules! rdata_types { for AllRecordData where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, NN: ToDname, + N: ToName, NN: ToName, { fn partial_cmp( &self, @@ -820,8 +820,8 @@ macro_rules! rdata_types { for AllRecordData where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: CanonicalOrd + ToDname, - NN: ToDname, + N: CanonicalOrd + ToName, + NN: ToName, { fn canonical_cmp( &self, @@ -914,7 +914,7 @@ macro_rules! rdata_types { impl<'a, Octs: Octets> ParseAnyRecordData<'a, Octs> - for AllRecordData, ParsedDname>> { + for AllRecordData, ParsedName>> { fn parse_any_rdata( rtype: Rtype, parser: &mut Parser<'a, Octs>, @@ -952,7 +952,7 @@ macro_rules! rdata_types { impl<'a, Octs: Octets> ParseRecordData<'a, Octs> - for AllRecordData, ParsedDname>> { + for AllRecordData, ParsedName>> { fn parse_rdata( rtype: Rtype, parser: &mut Parser<'a, Octs>, @@ -962,7 +962,7 @@ macro_rules! rdata_types { } impl ComposeRecordData for AllRecordData - where Octs: AsRef<[u8]>, Name: ToDname { + where Octs: AsRef<[u8]>, Name: ToName { fn rdlen(&self, compress: bool) -> Option { match *self { $( $( $( @@ -1105,13 +1105,13 @@ macro_rules! rdata_types { } } -//------------ dname_type! -------------------------------------------------- +//------------ name_type! -------------------------------------------------- /// A macro for implementing a record data type with a single domain name. /// /// Implements some basic methods plus the `RecordData`, `FlatRecordData`, /// and `Display` traits. -macro_rules! dname_type_base { +macro_rules! name_type_base { ($(#[$attr:meta])* ( $target:ident, $rtype:ident, $field:ident, $into_field:ident ) ) => { @@ -1144,10 +1144,10 @@ macro_rules! dname_type_base { self.$field } - pub fn scan>( + pub fn scan>( scanner: &mut S ) -> Result { - scanner.scan_dname().map(Self::new) + scanner.scan_name().map(Self::new) } pub(in crate::rdata) fn convert_octets>( @@ -1167,11 +1167,11 @@ macro_rules! dname_type_base { } } - impl $target> { + impl $target> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { - ParsedDname::parse(parser).map(Self::new) + ParsedName::parse(parser).map(Self::new) } } @@ -1224,13 +1224,13 @@ macro_rules! dname_type_base { //--- PartialEq and Eq impl PartialEq<$target> for $target - where N: ToDname, NN: ToDname { + where N: ToName, NN: ToName { fn eq(&self, other: &$target) -> bool { self.$field.name_eq(&other.$field) } } - impl Eq for $target { } + impl Eq for $target { } //--- PartialOrd and Ord @@ -1238,13 +1238,13 @@ macro_rules! dname_type_base { // For CanonicalOrd, see below. impl PartialOrd<$target> for $target - where N: ToDname, NN: ToDname { + where N: ToName, NN: ToName { fn partial_cmp(&self, other: &$target) -> Option { Some(self.$field.name_cmp(&other.$field)) } } - impl Ord for $target { + impl Ord for $target { fn cmp(&self, other: &Self) -> Ordering { self.$field.name_cmp(&other.$field) } @@ -1267,7 +1267,7 @@ macro_rules! dname_type_base { } impl<'a, Octs> $crate::base::rdata::ParseRecordData<'a, Octs> - for $target<$crate::base::name::ParsedDname>> + for $target<$crate::base::name::ParsedName>> where Octs: octseq::octets::Octets + ?Sized { fn parse_rdata( rtype: $crate::base::iana::Rtype, @@ -1292,16 +1292,16 @@ macro_rules! dname_type_base { } } -macro_rules! dname_type_well_known { +macro_rules! name_type_well_known { ($(#[$attr:meta])* ( $target:ident, $rtype:ident, $field:ident, $into_field:ident ) ) => { - dname_type_base! { + name_type_base! { $( #[$attr] )* ($target, $rtype, $field, $into_field) } - impl $crate::base::rdata::ComposeRecordData + impl $crate::base::rdata::ComposeRecordData for $target { fn rdlen(&self, compress: bool) -> Option { if compress { @@ -1316,7 +1316,7 @@ macro_rules! dname_type_well_known { &self, target: &mut Target ) -> Result<(), Target::AppendError> { if target.can_compress() { - target.append_compressed_dname(&self.$field) + target.append_compressed_name(&self.$field) } else { self.$field.compose(target) @@ -1331,7 +1331,7 @@ macro_rules! dname_type_well_known { } } - impl CanonicalOrd<$target> for $target { + impl CanonicalOrd<$target> for $target { fn canonical_cmp(&self, other: &$target) -> Ordering { self.$field.lowercase_composed_cmp(&other.$field) } @@ -1339,16 +1339,16 @@ macro_rules! dname_type_well_known { } } -macro_rules! dname_type_canonical { +macro_rules! name_type_canonical { ($(#[$attr:meta])* ( $target:ident, $rtype:ident, $field:ident, $into_field:ident ) ) => { - dname_type_base! { + name_type_base! { $( #[$attr] )* ($target, $rtype, $field, $into_field) } - impl $crate::base::rdata::ComposeRecordData + impl $crate::base::rdata::ComposeRecordData for $target { fn rdlen(&self, _compress: bool) -> Option { Some(self.$field.compose_len()) @@ -1368,7 +1368,7 @@ macro_rules! dname_type_canonical { } } - impl CanonicalOrd<$target> for $target { + impl CanonicalOrd<$target> for $target { fn canonical_cmp(&self, other: &$target) -> Ordering { self.$field.lowercase_composed_cmp(&other.$field) } @@ -1377,16 +1377,16 @@ macro_rules! dname_type_canonical { } #[allow(unused_macros)] -macro_rules! dname_type { +macro_rules! name_type { ($(#[$attr:meta])* ( $target:ident, $rtype:ident, $field:ident, $into_field:ident ) ) => { - dname_type_base! { + name_type_base! { $( #[$attr] )* ($target, $rtype, $field, $into_field) } - impl $crate::base::rdata::ComposeRecordData + impl $crate::base::rdata::ComposeRecordData for $target { fn rdlen(&self, _compress: bool) -> Option { Some(self.compose_len) @@ -1406,7 +1406,7 @@ macro_rules! dname_type { } } - impl CanonicalOrd<$target> for $target { + impl CanonicalOrd<$target> for $target { fn canonical_cmp(&self, other: &$target) -> Ordering { self.$field.name_cmp(&other.$field) } diff --git a/src/rdata/mod.rs b/src/rdata/mod.rs index 2019674d5..646f53669 100644 --- a/src/rdata/mod.rs +++ b/src/rdata/mod.rs @@ -29,7 +29,7 @@ // // RFC 3597 stipulates that only record data of record types defined in RFC // 1035 is allowed to be compressed. (These are called “well-known record -// types.”) For all other types, `CompressDname::append_compressed_dname` +// types.”) For all other types, `CompressDname::append_compressed_name` // must not be used and the names be composed with `ToDname::compose`. // // RFC 4034 defines the canonical form of record data. For this form, domain @@ -40,9 +40,9 @@ // `ToDname::compose`. // // The macros module contains three macros for generating name-only record -// types in these three categories: `dname_type_well_known!` for types from -// RFC 1035, `dname_type_canonical!` for non-RFC 1035 types that need to be -// lowercased, and `dname_type!` for everything else. +// types in these three categories: `name_type_well_known!` for types from +// RFC 1035, `name_type_canonical!` for non-RFC 1035 types that need to be +// lowercased, and `name_type!` for everything else. #[macro_use] mod macros; diff --git a/src/rdata/rfc1035/minfo.rs b/src/rdata/rfc1035/minfo.rs index b5fe6d92f..cf1b3ccbb 100644 --- a/src/rdata/rfc1035/minfo.rs +++ b/src/rdata/rfc1035/minfo.rs @@ -4,7 +4,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::Rtype; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ ComposeRecordData, ParseRecordData, RecordData, }; @@ -82,20 +82,20 @@ impl Minfo { )) } - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { - Ok(Self::new(scanner.scan_dname()?, scanner.scan_dname()?)) + Ok(Self::new(scanner.scan_name()?, scanner.scan_name()?)) } } -impl Minfo> { +impl Minfo> { pub fn parse<'a, Src: Octets = Octs> + ?Sized>( parser: &mut Parser<'a, Src>, ) -> Result { Ok(Self::new( - ParsedDname::parse(parser)?, - ParsedDname::parse(parser)?, + ParsedName::parse(parser)?, + ParsedName::parse(parser)?, )) } } @@ -131,8 +131,8 @@ where impl PartialEq> for Minfo where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Minfo) -> bool { self.rmailbx.name_eq(&other.rmailbx) @@ -140,14 +140,14 @@ where } } -impl Eq for Minfo {} +impl Eq for Minfo {} //--- PartialOrd, Ord, and CanonicalOrd impl PartialOrd> for Minfo where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Minfo) -> Option { match self.rmailbx.name_cmp(&other.rmailbx) { @@ -158,7 +158,7 @@ where } } -impl Ord for Minfo { +impl Ord for Minfo { fn cmp(&self, other: &Self) -> Ordering { match self.rmailbx.name_cmp(&other.rmailbx) { Ordering::Equal => {} @@ -168,7 +168,7 @@ impl Ord for Minfo { } } -impl CanonicalOrd> for Minfo { +impl CanonicalOrd> for Minfo { fn canonical_cmp(&self, other: &Minfo) -> Ordering { match self.rmailbx.lowercase_composed_cmp(&other.rmailbx) { Ordering::Equal => {} @@ -187,7 +187,7 @@ impl RecordData for Minfo { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Minfo>> + for Minfo>> { fn parse_rdata( rtype: Rtype, @@ -201,7 +201,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> } } -impl ComposeRecordData for Minfo { +impl ComposeRecordData for Minfo { fn rdlen(&self, compress: bool) -> Option { if compress { None @@ -215,8 +215,8 @@ impl ComposeRecordData for Minfo { target: &mut Target, ) -> Result<(), Target::AppendError> { if target.can_compress() { - target.append_compressed_dname(&self.rmailbx)?; - target.append_compressed_dname(&self.emailbx) + target.append_compressed_name(&self.rmailbx)?; + target.append_compressed_name(&self.emailbx) } else { self.rmailbx.compose(target)?; self.emailbx.compose(target) @@ -246,7 +246,7 @@ impl fmt::Display for Minfo { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; @@ -256,9 +256,9 @@ mod test { #[test] #[allow(clippy::redundant_closure)] // lifetimes ... fn minfo_compose_parse_scan() { - let rdata = Minfo::>>::new( - Dname::from_str("r.example.com").unwrap(), - Dname::from_str("e.example.com").unwrap(), + let rdata = Minfo::>>::new( + Name::from_str("r.example.com").unwrap(), + Name::from_str("e.example.com").unwrap(), ); test_rdlen(&rdata); test_compose_parse(&rdata, |parser| Minfo::parse(parser)); @@ -267,11 +267,11 @@ mod test { #[test] fn minfo_octets_into() { - let minfo: Minfo>> = Minfo::new( + let minfo: Minfo>> = Minfo::new( "a.example".parse().unwrap(), "b.example".parse().unwrap(), ); - let minfo_bytes: Minfo> = + let minfo_bytes: Minfo> = minfo.clone().octets_into(); assert_eq!(minfo.rmailbx(), minfo_bytes.rmailbx()); assert_eq!(minfo.emailbx(), minfo_bytes.emailbx()); diff --git a/src/rdata/rfc1035/mod.rs b/src/rdata/rfc1035/mod.rs index adb2ed912..06d5d2a53 100644 --- a/src/rdata/rfc1035/mod.rs +++ b/src/rdata/rfc1035/mod.rs @@ -5,7 +5,7 @@ //! [RFC 1035]: https://tools.ietf.org/html/rfc1035 pub use self::a::A; -pub use self::dname::{Cname, Mb, Md, Mf, Mg, Mr, Ns, Ptr}; +pub use self::name::{Cname, Mb, Md, Mf, Mg, Mr, Ns, Ptr}; pub use self::hinfo::Hinfo; pub use self::minfo::Minfo; pub use self::mx::Mx; @@ -16,7 +16,7 @@ pub use self::txt::{ }; mod a; -mod dname; +mod name; mod hinfo; mod minfo; mod mx; diff --git a/src/rdata/rfc1035/mx.rs b/src/rdata/rfc1035/mx.rs index c0fc823d5..7e83b8ec6 100644 --- a/src/rdata/rfc1035/mx.rs +++ b/src/rdata/rfc1035/mx.rs @@ -4,7 +4,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::Rtype; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ ComposeRecordData, ParseRecordData, RecordData, }; @@ -70,18 +70,18 @@ impl Mx { Ok(Mx::new(self.preference, self.exchange.try_flatten_into()?)) } - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { - Ok(Self::new(u16::scan(scanner)?, scanner.scan_dname()?)) + Ok(Self::new(u16::scan(scanner)?, scanner.scan_name()?)) } } -impl Mx> { +impl Mx> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { - Ok(Self::new(u16::parse(parser)?, ParsedDname::parse(parser)?)) + Ok(Self::new(u16::parse(parser)?, ParsedName::parse(parser)?)) } } @@ -116,8 +116,8 @@ where impl PartialEq> for Mx where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Mx) -> bool { self.preference == other.preference @@ -125,14 +125,14 @@ where } } -impl Eq for Mx {} +impl Eq for Mx {} //--- PartialOrd, Ord, and CanonicalOrd impl PartialOrd> for Mx where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Mx) -> Option { match self.preference.partial_cmp(&other.preference) { @@ -143,7 +143,7 @@ where } } -impl Ord for Mx { +impl Ord for Mx { fn cmp(&self, other: &Self) -> Ordering { match self.preference.cmp(&other.preference) { Ordering::Equal => {} @@ -153,7 +153,7 @@ impl Ord for Mx { } } -impl CanonicalOrd> for Mx { +impl CanonicalOrd> for Mx { fn canonical_cmp(&self, other: &Mx) -> Ordering { match self.preference.cmp(&other.preference) { Ordering::Equal => {} @@ -172,7 +172,7 @@ impl RecordData for Mx { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Mx>> + for Mx>> { fn parse_rdata( rtype: Rtype, @@ -186,7 +186,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> } } -impl ComposeRecordData for Mx { +impl ComposeRecordData for Mx { fn rdlen(&self, compress: bool) -> Option { if compress { None @@ -201,7 +201,7 @@ impl ComposeRecordData for Mx { ) -> Result<(), Target::AppendError> { if target.can_compress() { self.preference.compose(target)?; - target.append_compressed_dname(&self.exchange) + target.append_compressed_name(&self.exchange) } else { self.preference.compose(target)?; self.exchange.compose(target) @@ -231,7 +231,7 @@ impl fmt::Display for Mx { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; @@ -241,9 +241,9 @@ mod test { #[test] #[allow(clippy::redundant_closure)] // lifetimes ... fn mx_compose_parse_scan() { - let rdata = Mx::>>::new( + let rdata = Mx::>>::new( 12, - Dname::from_str("mail.example.com").unwrap(), + Name::from_str("mail.example.com").unwrap(), ); test_rdlen(&rdata); test_compose_parse(&rdata, |parser| Mx::parse(parser)); diff --git a/src/rdata/rfc1035/dname.rs b/src/rdata/rfc1035/name.rs similarity index 91% rename from src/rdata/rfc1035/dname.rs rename to src/rdata/rfc1035/name.rs index 39a1979f2..78a7907c0 100644 --- a/src/rdata/rfc1035/dname.rs +++ b/src/rdata/rfc1035/name.rs @@ -3,7 +3,7 @@ //! This is a private module. It’s content is re-exported by the parent. use crate::base::cmp::CanonicalOrd; -use crate::base::name::{ParsedDname, ToDname}; +use crate::base::name::{ParsedName, ToName}; use crate::base::wire::ParseError; use core::{fmt, hash, str}; use core::cmp::Ordering; @@ -13,7 +13,7 @@ use octseq::parse::Parser; //------------ Cname -------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// CNAME record data. /// /// The CNAME record specifies the canonical or primary name for domain @@ -25,7 +25,7 @@ dname_type_well_known! { //------------ Mb ----------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// MB record data. /// /// The experimental MB record specifies a host that serves a mailbox. @@ -36,7 +36,7 @@ dname_type_well_known! { //------------ Md ----------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// MD record data. /// /// The MD record specifices a host which has a mail agent for @@ -51,7 +51,7 @@ dname_type_well_known! { //------------ Mf ----------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// MF record data. /// /// The MF record specifices a host which has a mail agent for @@ -66,7 +66,7 @@ dname_type_well_known! { //------------ Mg ----------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// MG record data. /// /// The MG record specifices a mailbox which is a member of the mail group @@ -80,7 +80,7 @@ dname_type_well_known! { //------------ Mr ----------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// MR record data. /// /// The MR record specifices a mailbox which is the proper rename of the @@ -94,7 +94,7 @@ dname_type_well_known! { //------------ Ns ----------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// NS record data. /// /// NS records specify hosts that are authoritative for a class and domain. @@ -105,7 +105,7 @@ dname_type_well_known! { //------------ Ptr ---------------------------------------------------------- -dname_type_well_known! { +name_type_well_known! { /// PTR record data. /// /// PRT records are used in special domains to point to some other location @@ -121,7 +121,7 @@ dname_type_well_known! { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; @@ -133,7 +133,7 @@ mod test { #[allow(clippy::redundant_closure)] // lifetimes ... fn cname_compose_parse_scan() { let rdata = - Cname::>>::from_str("www.example.com").unwrap(); + Cname::>>::from_str("www.example.com").unwrap(); test_rdlen(&rdata); test_compose_parse(&rdata, |parser| Cname::parse(parser)); test_scan(&["www.example.com"], Cname::scan, &rdata); diff --git a/src/rdata/rfc1035/soa.rs b/src/rdata/rfc1035/soa.rs index 95b3330b2..670897cf4 100644 --- a/src/rdata/rfc1035/soa.rs +++ b/src/rdata/rfc1035/soa.rs @@ -4,7 +4,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::Rtype; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ ComposeRecordData, ParseRecordData, RecordData, }; @@ -128,12 +128,12 @@ impl Soa { )) } - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { Ok(Self::new( - scanner.scan_dname()?, - scanner.scan_dname()?, + scanner.scan_name()?, + scanner.scan_name()?, Serial::scan(scanner)?, Ttl::scan(scanner)?, Ttl::scan(scanner)?, @@ -143,13 +143,13 @@ impl Soa { } } -impl Soa> { +impl Soa> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { Ok(Self::new( - ParsedDname::parse(parser)?, - ParsedDname::parse(parser)?, + ParsedName::parse(parser)?, + ParsedName::parse(parser)?, Serial::parse(parser)?, Ttl::parse(parser)?, Ttl::parse(parser)?, @@ -195,8 +195,8 @@ where impl PartialEq> for Soa where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Soa) -> bool { self.mname.name_eq(&other.mname) @@ -209,14 +209,14 @@ where } } -impl Eq for Soa {} +impl Eq for Soa {} //--- PartialOrd, Ord, and CanonicalOrd impl PartialOrd> for Soa where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Soa) -> Option { match self.mname.name_cmp(&other.mname) { @@ -247,7 +247,7 @@ where } } -impl Ord for Soa { +impl Ord for Soa { fn cmp(&self, other: &Self) -> Ordering { match self.mname.name_cmp(&other.mname) { Ordering::Equal => {} @@ -277,7 +277,7 @@ impl Ord for Soa { } } -impl CanonicalOrd> for Soa { +impl CanonicalOrd> for Soa { fn canonical_cmp(&self, other: &Soa) -> Ordering { match self.mname.lowercase_composed_cmp(&other.mname) { Ordering::Equal => {} @@ -316,7 +316,7 @@ impl RecordData for Soa { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Soa>> + for Soa>> { fn parse_rdata( rtype: Rtype, @@ -330,7 +330,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> } } -impl ComposeRecordData for Soa { +impl ComposeRecordData for Soa { fn rdlen(&self, compress: bool) -> Option { if compress { None @@ -349,8 +349,8 @@ impl ComposeRecordData for Soa { target: &mut Target, ) -> Result<(), Target::AppendError> { if target.can_compress() { - target.append_compressed_dname(&self.mname)?; - target.append_compressed_dname(&self.rname)?; + target.append_compressed_name(&self.mname)?; + target.append_compressed_name(&self.rname)?; } else { self.mname.compose(target)?; self.rname.compose(target)?; @@ -368,7 +368,7 @@ impl ComposeRecordData for Soa { } } -impl Soa { +impl Soa { fn compose_fixed( &self, target: &mut Target, @@ -405,7 +405,7 @@ impl fmt::Display for Soa { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; @@ -415,9 +415,9 @@ mod test { #[test] #[allow(clippy::redundant_closure)] // lifetimes ... fn soa_compose_parse_scan() { - let rdata = Soa::>>::new( - Dname::from_str("m.example.com").unwrap(), - Dname::from_str("r.example.com").unwrap(), + let rdata = Soa::>>::new( + Name::from_str("m.example.com").unwrap(), + Name::from_str("r.example.com").unwrap(), Serial(11), Ttl::from_secs(12), Ttl::from_secs(13), diff --git a/src/rdata/srv.rs b/src/rdata/srv.rs index 7d83eecc1..f0ce9dd8b 100644 --- a/src/rdata/srv.rs +++ b/src/rdata/srv.rs @@ -6,7 +6,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::Rtype; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ComposeRecordData, ParseRecordData, RecordData}; use crate::base::scan::{Scan, Scanner}; use crate::base::wire::{Compose, Composer, Parse, ParseError}; @@ -85,19 +85,19 @@ impl Srv { )) } - pub fn scan>( + pub fn scan>( scanner: &mut S, ) -> Result { Ok(Self::new( u16::scan(scanner)?, u16::scan(scanner)?, u16::scan(scanner)?, - scanner.scan_dname()?, + scanner.scan_name()?, )) } } -impl Srv> { +impl Srv> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { @@ -105,7 +105,7 @@ impl Srv> { u16::parse(parser)?, u16::parse(parser)?, u16::parse(parser)?, - ParsedDname::parse(parser)?, + ParsedName::parse(parser)?, )) } } @@ -140,8 +140,8 @@ impl, TName> FlattenInto> for Srv { impl PartialEq> for Srv where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Srv) -> bool { self.priority == other.priority @@ -151,14 +151,14 @@ where } } -impl Eq for Srv {} +impl Eq for Srv {} //--- PartialOrd, Ord, and CanonicalOrd impl PartialOrd> for Srv where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Srv) -> Option { match self.priority.partial_cmp(&other.priority) { @@ -177,7 +177,7 @@ where } } -impl Ord for Srv { +impl Ord for Srv { fn cmp(&self, other: &Self) -> Ordering { match self.priority.cmp(&other.priority) { Ordering::Equal => {} @@ -195,7 +195,7 @@ impl Ord for Srv { } } -impl CanonicalOrd> for Srv { +impl CanonicalOrd> for Srv { fn canonical_cmp(&self, other: &Srv) -> Ordering { match self.priority.cmp(&other.priority) { Ordering::Equal => {} @@ -222,7 +222,7 @@ impl RecordData for Srv { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Srv>> + for Srv>> { fn parse_rdata( rtype: Rtype, @@ -236,7 +236,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> } } -impl ComposeRecordData for Srv { +impl ComposeRecordData for Srv { fn rdlen(&self, _compress: bool) -> Option { // SRV records are not compressed. Some(self.target.compose_len() + 6) @@ -259,7 +259,7 @@ impl ComposeRecordData for Srv { } } -impl Srv { +impl Srv { fn compose_head( &self, target: &mut Target, @@ -288,7 +288,7 @@ impl fmt::Display for Srv { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{ test_compose_parse, test_rdlen, test_scan, }; @@ -302,7 +302,7 @@ mod test { 10, 11, 12, - Dname::>::from_str("example.com.").unwrap(), + Name::>::from_str("example.com.").unwrap(), ); test_rdlen(&rdata); test_compose_parse(&rdata, |parser| Srv::parse(parser)); diff --git a/src/rdata/svcb/rdata.rs b/src/rdata/svcb/rdata.rs index 02c973cff..7e1dea983 100644 --- a/src/rdata/svcb/rdata.rs +++ b/src/rdata/svcb/rdata.rs @@ -5,7 +5,7 @@ use super::SvcParams; use crate::base::cmp::CanonicalOrd; use crate::base::iana::Rtype; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ ComposeRecordData, LongRecordData, ParseRecordData, RecordData, }; @@ -134,7 +134,7 @@ impl SvcbRdata { pub fn new( priority: u16, target: Name, params: SvcParams ) -> Result - where Octs: AsRef<[u8]>, Name: ToDname { + where Octs: AsRef<[u8]>, Name: ToName { LongRecordData::check_len( usize::from( u16::COMPOSE_LEN + target.compose_len() @@ -156,13 +156,13 @@ impl SvcbRdata { } } -impl> SvcbRdata> { +impl> SvcbRdata> { /// Parses service bindings record data from its wire format. pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src> ) -> Result { let priority = u16::parse(parser)?; - let target = ParsedDname::parse(parser)?; + let target = ParsedName::parse(parser)?; let params = SvcParams::parse(parser)?; Ok(unsafe { Self::new_unchecked(priority, target, params) @@ -279,8 +279,8 @@ for SvcbRdata where Octs: AsRef<[u8]>, OtherOcts: AsRef<[u8]>, - Name: ToDname, - OtherName: ToDname, + Name: ToName, + OtherName: ToName, { fn eq( &self, other: &SvcbRdata @@ -291,7 +291,7 @@ where } } -impl, Name: ToDname> Eq +impl, Name: ToName> Eq for SvcbRdata { } //--- Hash @@ -313,8 +313,8 @@ for SvcbRdata where Octs: AsRef<[u8]>, OtherOcts: AsRef<[u8]>, - Name: ToDname, - OtherName: ToDname, + Name: ToName, + OtherName: ToName, { fn partial_cmp( &self, other: &SvcbRdata @@ -331,7 +331,7 @@ where } } -impl, Name: ToDname> Ord +impl, Name: ToName> Ord for SvcbRdata { fn cmp(&self, other: &Self) -> cmp::Ordering { match self.priority.cmp(&other.priority) { @@ -352,8 +352,8 @@ for SvcbRdata where Octs: AsRef<[u8]>, OtherOcts: AsRef<[u8]>, - Name: ToDname, - OtherName: ToDname, + Name: ToName, + OtherName: ToName, { fn canonical_cmp( &self, other: &SvcbRdata @@ -385,7 +385,7 @@ impl RecordData for SvcbRdata { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> -for SvcbRdata, ParsedDname>> { +for SvcbRdata, ParsedName>> { fn parse_rdata( rtype: Rtype, parser: &mut Parser<'a, Octs> ) -> Result, ParseError> { @@ -399,7 +399,7 @@ for SvcbRdata, ParsedDname>> { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> -for SvcbRdata, ParsedDname>> { +for SvcbRdata, ParsedName>> { fn parse_rdata( rtype: Rtype, parser: &mut Parser<'a, Octs> ) -> Result, ParseError> { @@ -413,7 +413,7 @@ for SvcbRdata, ParsedDname>> { } impl ComposeRecordData for SvcbRdata -where Self: RecordData, Octs: AsRef<[u8]>, Name: ToDname { +where Self: RecordData, Octs: AsRef<[u8]>, Name: ToName { fn rdlen(&self, _compress: bool) -> Option { Some( u16::checked_add( @@ -471,12 +471,12 @@ mod test { use super::*; use super::super::UnknownSvcParam; use super::super::value::AllValues; - use crate::base::Dname; + use crate::base::Name; use octseq::array::Array; use core::str::FromStr; type Octets512 = Array<512>; - type Dname512 = Dname>; + type Dname512 = Name>; type Params512 = SvcParams>; // We only do two tests here to see if the SvcbRdata type itself is diff --git a/src/rdata/tsig.rs b/src/rdata/tsig.rs index bfd7fba25..f0f4bcfc9 100644 --- a/src/rdata/tsig.rs +++ b/src/rdata/tsig.rs @@ -6,7 +6,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::{Rtype, TsigRcode}; -use crate::base::name::{FlattenInto, ParsedDname, ToDname}; +use crate::base::name::{FlattenInto, ParsedName, ToName}; use crate::base::rdata::{ ComposeRecordData, LongRecordData, ParseRecordData, RecordData }; @@ -98,7 +98,7 @@ impl Tsig { error: TsigRcode, other: O, ) -> Result - where O: AsRef<[u8]>, N: ToDname { + where O: AsRef<[u8]>, N: ToName { LongRecordData::check_len( 6 // time_signed + 2 // fudge @@ -288,11 +288,11 @@ impl Tsig { } } -impl Tsig> { +impl Tsig> { pub fn parse<'a, Src: Octets = Octs> + ?Sized + 'a>( parser: &mut Parser<'a, Src>, ) -> Result { - let algorithm = ParsedDname::parse(parser)?; + let algorithm = ParsedName::parse(parser)?; let time_signed = Time48::parse(parser)?; let fudge = u16::parse(parser)?; let mac_size = u16::parse(parser)?; @@ -359,8 +359,8 @@ impl PartialEq> for Tsig where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn eq(&self, other: &Tsig) -> bool { self.algorithm.name_eq(&other.algorithm) @@ -373,7 +373,7 @@ where } } -impl, N: ToDname> Eq for Tsig {} +impl, N: ToName> Eq for Tsig {} //--- PartialOrd, Ord, and CanonicalOrd @@ -381,8 +381,8 @@ impl PartialOrd> for Tsig where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn partial_cmp(&self, other: &Tsig) -> Option { match self.algorithm.name_cmp(&other.algorithm) { @@ -413,7 +413,7 @@ where } } -impl, N: ToDname> Ord for Tsig { +impl, N: ToName> Ord for Tsig { fn cmp(&self, other: &Self) -> Ordering { match self.algorithm.name_cmp(&other.algorithm) { Ordering::Equal => {} @@ -447,8 +447,8 @@ impl CanonicalOrd> for Tsig where O: AsRef<[u8]>, OO: AsRef<[u8]>, - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, { fn canonical_cmp(&self, other: &Tsig) -> Ordering { match self.algorithm.composed_cmp(&other.algorithm) { @@ -510,7 +510,7 @@ impl RecordData for Tsig { } impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> - for Tsig, ParsedDname>> + for Tsig, ParsedName>> { fn parse_rdata( rtype: Rtype, @@ -524,7 +524,7 @@ impl<'a, Octs: Octets + ?Sized> ParseRecordData<'a, Octs> } } -impl, Name: ToDname> ComposeRecordData +impl, Name: ToName> ComposeRecordData for Tsig { fn rdlen(&self, _compress: bool) -> Option { @@ -716,7 +716,7 @@ impl fmt::Display for Time48 { #[cfg(all(feature = "std", feature = "bytes"))] mod test { use super::*; - use crate::base::name::Dname; + use crate::base::name::Name; use crate::base::rdata::test::{test_compose_parse, test_rdlen}; use core::str::FromStr; use std::vec::Vec; @@ -725,7 +725,7 @@ mod test { #[allow(clippy::redundant_closure)] // lifetimes ... fn tsig_compose_parse_scan() { let rdata = Tsig::new( - Dname::>::from_str("key.example.com.").unwrap(), + Name::>::from_str("key.example.com.").unwrap(), Time48::now(), 12, "foo", diff --git a/src/rdata/zonemd.rs b/src/rdata/zonemd.rs index 8cbe990bd..b94140673 100644 --- a/src/rdata/zonemd.rs +++ b/src/rdata/zonemd.rs @@ -396,7 +396,7 @@ mod test { #[cfg(feature = "zonefile")] #[test] fn zonemd_parse_zonefile() { - use crate::base::Dname; + use crate::base::Name; use crate::rdata::ZoneRecordData; use crate::zonefile::inplace::{Entry, Zonefile}; @@ -418,7 +418,7 @@ ns2 3600 IN AAAA 2001:db8::63 "#; let mut zone = Zonefile::load(&mut content.as_bytes()).unwrap(); - zone.set_origin(Dname::root()); + zone.set_origin(Name::root()); while let Some(entry) = zone.next_entry().unwrap() { match entry { Entry::Record(record) => { diff --git a/src/resolv/lookup/addr.rs b/src/resolv/lookup/addr.rs index 0b63905b1..7adf3023e 100644 --- a/src/resolv/lookup/addr.rs +++ b/src/resolv/lookup/addr.rs @@ -2,13 +2,12 @@ use crate::base::iana::Rtype; use crate::base::message::RecordIter; -use crate::base::name::{Dname, DnameBuilder, ParsedDname}; +use crate::base::name::{Name, ParsedName}; use crate::rdata::Ptr; use crate::resolv::resolver::Resolver; use octseq::octets::Octets; use std::io; use std::net::IpAddr; -use std::str::FromStr; //------------ Octets128 ----------------------------------------------------- @@ -28,7 +27,8 @@ pub async fn lookup_addr( resolv: &R, addr: IpAddr, ) -> Result, io::Error> { - let name = dname_from_addr(addr); + let name = Name::::reverse_from_addr(addr) + .expect("address domain name too long"); resolv.query((name, Rtype::PTR)).await.map(FoundAddrs) } @@ -63,7 +63,7 @@ impl<'a, R: Resolver> IntoIterator for &'a FoundAddrs where R::Octets: Octets, { - type Item = ParsedDname<<::Octets as Octets>::Range<'a>>; + type Item = ParsedName<<::Octets as Octets>::Range<'a>>; type IntoIter = FoundAddrsIter<'a, R::Octets>; fn into_iter(self) -> Self::IntoIter { @@ -75,12 +75,12 @@ where /// An iterator over host names returned by address lookup. pub struct FoundAddrsIter<'a, Octs: Octets> { - name: Option>>, - answer: Option>>>>, + name: Option>>, + answer: Option>>>>, } impl<'a, Octs: Octets> Iterator for FoundAddrsIter<'a, Octs> { - type Item = ParsedDname>; + type Item = ParsedName>; #[allow(clippy::while_let_on_iterator)] fn next(&mut self) -> Option { @@ -94,78 +94,3 @@ impl<'a, Octs: Octets> Iterator for FoundAddrsIter<'a, Octs> { None } } - -//------------ Helper Functions --------------------------------------------- - -/// Translates an IP address into a domain name. -fn dname_from_addr(addr: IpAddr) -> Dname { - match addr { - IpAddr::V4(addr) => { - let octets = addr.octets(); - Dname::from_str(&format!( - "{}.{}.{}.{}.in-addr.arpa.", - octets[3], octets[2], octets[1], octets[0] - )) - .unwrap() - } - IpAddr::V6(addr) => { - let mut res = DnameBuilder::::new(); - for &item in addr.octets().iter().rev() { - res.append_label(&[hexdigit(item)]).unwrap(); - res.append_label(&[hexdigit(item >> 4)]).unwrap(); - } - res.append_label(b"ip6").unwrap(); - res.append_label(b"arpa").unwrap(); - res.into_dname().unwrap() - } - } -} - -fn hexdigit(nibble: u8) -> u8 { - match nibble & 0x0F { - 0 => b'0', - 1 => b'1', - 2 => b'2', - 3 => b'3', - 4 => b'4', - 5 => b'5', - 6 => b'6', - 7 => b'7', - 8 => b'8', - 9 => b'9', - 10 => b'A', - 11 => b'B', - 12 => b'C', - 13 => b'D', - 14 => b'E', - 15 => b'F', - _ => unreachable!(), - } -} - -//============ Tests ========================================================= - -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_dname_from_addr() { - assert_eq!( - dname_from_addr([192, 0, 2, 12].into()), - Dname::::from_str("12.2.0.192.in-addr.arpa").unwrap() - ); - assert_eq!( - dname_from_addr( - [0x2001, 0xdb8, 0x1234, 0x0, 0x5678, 0x1, 0x9abc, 0xdef] - .into() - ), - Dname::::from_str( - "f.e.d.0.c.b.a.9.1.0.0.0.8.7.6.5.\ - 0.0.0.0.4.3.2.1.8.b.d.0.1.0.0.2.\ - ip6.arpa" - ) - .unwrap() - ); - } -} diff --git a/src/resolv/lookup/host.rs b/src/resolv/lookup/host.rs index e4424e220..cbf248971 100644 --- a/src/resolv/lookup/host.rs +++ b/src/resolv/lookup/host.rs @@ -2,7 +2,7 @@ use crate::base::iana::Rtype; use crate::base::message::RecordIter; -use crate::base::name::{ParsedDname, ToDname, ToRelativeDname}; +use crate::base::name::{ParsedName, ToName, ToRelativeName}; use crate::rdata::{Aaaa, A}; use crate::resolv::resolver::{Resolver, SearchNames}; use octseq::octets::Octets; @@ -22,7 +22,7 @@ use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; /// return the canonical name. pub async fn lookup_host( resolver: &R, - qname: impl ToDname, + qname: impl ToName, ) -> Result, io::Error> { let (a, aaaa) = tokio::join!( resolver.query((&qname, Rtype::A)), @@ -35,7 +35,7 @@ pub async fn lookup_host( pub async fn search_host( resolver: &R, - qname: impl ToRelativeDname, + qname: impl ToRelativeName, ) -> Result, io::Error> { for suffix in resolver.search_iter() { if let Ok(name) = (&qname).chain(suffix) { @@ -110,7 +110,7 @@ impl FoundHosts where R::Octets: Octets, { - pub fn qname(&self) -> ParsedDname<::Range<'_>> { + pub fn qname(&self) -> ParsedName<::Range<'_>> { self.answer() .as_ref() .first_question() @@ -127,7 +127,7 @@ where /// one of them. pub fn canonical_name( &self, - ) -> ParsedDname<::Range<'_>> { + ) -> ParsedName<::Range<'_>> { self.answer().as_ref().canonical_name().unwrap() } @@ -179,8 +179,8 @@ where /// An iterator over the IP addresses returned by a host lookup. #[derive(Clone)] pub struct FoundHostsIter<'a> { - aaaa_name: Option>, - a_name: Option>, + aaaa_name: Option>, + a_name: Option>, aaaa: Option>, a: Option>, } diff --git a/src/resolv/lookup/srv.rs b/src/resolv/lookup/srv.rs index 171d9a67f..f4a79f292 100644 --- a/src/resolv/lookup/srv.rs +++ b/src/resolv/lookup/srv.rs @@ -3,7 +3,7 @@ use super::host::lookup_host; use crate::base::iana::{Class, Rtype}; use crate::base::message::Message; -use crate::base::name::{Dname, ToDname, ToRelativeDname}; +use crate::base::name::{Name, ToName, ToRelativeName}; use crate::base::wire::ParseError; use crate::rdata::{Aaaa, Srv, A}; use crate::resolv::resolver::Resolver; @@ -63,8 +63,8 @@ type OctetsVec = Vec; ///[`TcpStream::connect`]: tokio::net::TcpStream::connect pub async fn lookup_srv( resolver: &impl Resolver, - service: impl ToRelativeDname, - name: impl ToDname, + service: impl ToRelativeName, + name: impl ToName, fallback_port: u16, ) -> Result, SrvError> { let full_name = match (&service).chain(&name) { @@ -126,7 +126,7 @@ impl FoundSrvs { /// /// If not results were found, the iterator will yield a single entry /// with the bare host and the default fallback port. - pub fn into_srvs(self) -> impl Iterator>> { + pub fn into_srvs(self) -> impl Iterator>> { let (left, right) = match self.items { Ok(ok) => (Some(ok.into_iter()), None), Err(err) => (None, Some(std::iter::once(err))), @@ -162,7 +162,7 @@ impl FoundSrvs { impl FoundSrvs { fn new( answer: &Message<[u8]>, - fallback_name: impl ToDname, + fallback_name: impl ToName, fallback_port: u16, ) -> Result, SrvError> { let name = @@ -187,7 +187,7 @@ impl FoundSrvs { fn process_records( answer: &Message<[u8]>, - name: &impl ToDname, + name: &impl ToName, ) -> Result, SrvError> { let mut res = Vec::new(); // XXX We could also error out if any SRV error is broken? @@ -280,7 +280,7 @@ impl FoundSrvs { #[derive(Clone, Debug)] pub struct SrvItem { /// The SRV record. - srv: Srv>, + srv: Srv>, /// Fall back? #[allow(dead_code)] // XXX Check if we can actually remove it. @@ -291,22 +291,22 @@ pub struct SrvItem { } impl SrvItem { - fn from_rdata(srv: &Srv) -> Self { + fn from_rdata(srv: &Srv) -> Self { SrvItem { srv: Srv::new( srv.priority(), srv.weight(), srv.port(), - srv.target().to_dname(), + srv.target().to_name(), ), fallback: false, resolved: None, } } - fn fallback(name: impl ToDname, fallback_port: u16) -> Self { + fn fallback(name: impl ToName, fallback_port: u16) -> Self { SrvItem { - srv: Srv::new(0, 0, fallback_port, name.to_dname()), + srv: Srv::new(0, 0, fallback_port, name.to_name()), fallback: true, resolved: None, } @@ -345,14 +345,14 @@ impl SrvItem { } } -impl AsRef>> for SrvItem { - fn as_ref(&self) -> &Srv> { +impl AsRef>> for SrvItem { + fn as_ref(&self) -> &Srv> { &self.srv } } impl ops::Deref for SrvItem { - type Target = Srv>; + type Target = Srv>; fn deref(&self) -> &Self::Target { self.as_ref() @@ -364,7 +364,7 @@ impl ops::Deref for SrvItem { /// An SRV record which has itself been resolved into a [`SocketAddr`]. #[derive(Clone, Debug)] pub struct ResolvedSrvItem { - srv: Srv>, + srv: Srv>, resolved: Vec, } @@ -375,14 +375,14 @@ impl ResolvedSrvItem { } } -impl AsRef>> for ResolvedSrvItem { - fn as_ref(&self) -> &Srv> { +impl AsRef>> for ResolvedSrvItem { + fn as_ref(&self) -> &Srv> { &self.srv } } impl ops::Deref for ResolvedSrvItem { - type Target = Srv>; + type Target = Srv>; fn deref(&self) -> &Self::Target { self.as_ref() diff --git a/src/resolv/resolver.rs b/src/resolv/resolver.rs index 49dc7587b..c26584cb8 100644 --- a/src/resolv/resolver.rs +++ b/src/resolv/resolver.rs @@ -1,7 +1,7 @@ //! The trait defining an abstract resolver. use crate::base::message::Message; -use crate::base::name::ToDname; +use crate::base::name::ToName; use crate::base::question::Question; use std::future::Future; use std::io; @@ -33,7 +33,7 @@ pub trait Resolver { /// produces a future trying to answer the question. fn query(&self, question: Q) -> Self::Query where - N: ToDname, + N: ToName, Q: Into>; } @@ -48,7 +48,7 @@ pub trait Resolver { /// A search resolver is a resolver that provides such a list. This is /// implemented via an iterator over domain names. pub trait SearchNames { - type Name: ToDname; + type Name: ToName; type Iter: Iterator; /// Returns an iterator over the search suffixes. diff --git a/src/resolv/stub/conf.rs b/src/resolv/stub/conf.rs index cc23256c4..12e6839b9 100644 --- a/src/resolv/stub/conf.rs +++ b/src/resolv/stub/conf.rs @@ -8,7 +8,7 @@ //! //! Both parts are modeled along the lines of glibc’s resolver. -use crate::base::name::{self, Dname}; +use crate::base::name::{self, Name}; use smallvec::SmallVec; use std::cmp::Ordering; use std::default::Default; @@ -202,35 +202,13 @@ impl Default for ResolvOptions { /// The transport protocol to be used for a server. #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum Transport { - /// Unencrypted UDP transport. - Udp, + /// Unencrypted UDP transport, switch to TCP for truncated responses. + UdpTcp, /// Unencrypted TCP transport. Tcp, } -impl Transport { - /// Returns whether the transport is a preferred transport. - /// - /// Only preferred transports are considered initially. Only if a - /// truncated answer comes back will we consider streaming protocols - /// instead. - pub fn is_preferred(self) -> bool { - match self { - Transport::Udp => true, - Transport::Tcp => false, - } - } - - /// Returns whether the transport is a streaming protocol. - pub fn is_stream(self) -> bool { - match self { - Transport::Udp => false, - Transport::Tcp => true, - } - } -} - //------------ ServerConf ---------------------------------------------------- /// Configuration for one upstream DNS server. @@ -344,13 +322,13 @@ impl ResolvConf { if self.servers.is_empty() { // glibc just simply uses 127.0.0.1:53. Let's do that, too, // and claim it is for compatibility. - let addr = - SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53); - self.servers.push(ServerConf::new(addr, Transport::Udp)); - self.servers.push(ServerConf::new(addr, Transport::Tcp)); + self.servers.push(ServerConf::new( + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 53), + Transport::UdpTcp, + )); } if self.options.search.is_empty() { - self.options.search.push(Dname::root()) + self.options.search.push(Name::root()) } for server in &mut self.servers { server.request_timeout = self.options.timeout @@ -409,8 +387,7 @@ impl ResolvConf { use std::net::ToSocketAddrs; for addr in (next_word(&mut words)?, 53).to_socket_addrs()? { - self.servers.push(ServerConf::new(addr, Transport::Udp)); - self.servers.push(ServerConf::new(addr, Transport::Tcp)); + self.servers.push(ServerConf::new(addr, Transport::UdpTcp)); } no_more_words(words) } @@ -610,7 +587,7 @@ impl fmt::Display for ResolvConf { //------------ SearchSuffix -------------------------------------------------- -pub type SearchSuffix = Dname>; +pub type SearchSuffix = Name>; //------------ SearchList ---------------------------------------------------- @@ -629,7 +606,7 @@ impl SearchList { } pub fn push_root(&mut self) { - self.search.push(Dname::root()) + self.search.push(Name::root()) } pub fn len(&self) -> usize { diff --git a/src/resolv/stub/mod.rs b/src/resolv/stub/mod.rs index fbc69377a..c6f7ddb0e 100644 --- a/src/resolv/stub/mod.rs +++ b/src/resolv/stub/mod.rs @@ -15,7 +15,7 @@ use self::conf::{ use crate::base::iana::Rcode; use crate::base::message::Message; use crate::base::message_builder::{AdditionalBuilder, MessageBuilder}; -use crate::base::name::{ToDname, ToRelativeDname}; +use crate::base::name::{ToName, ToRelativeName}; use crate::base::question::Question; use crate::net::client::dgram_stream; use crate::net::client::multi_stream; @@ -103,7 +103,7 @@ impl StubResolver { &self.options } - pub async fn query>>( + pub async fn query>>( &self, question: Q, ) -> Result { @@ -146,35 +146,24 @@ impl StubResolver { // We have 3 modes of operation: use_vc: only use TCP, ign_tc: only // UDP no fallback to TCP, and normal with is UDP falling back to TCP. - if self.options.use_vc { - for s in &self.servers { - if let Transport::Tcp = s.transport { - let (conn, tran) = multi_stream::Connection::new( - TcpConnect::new(s.addr), - ); - // Start the run function on a separate task. - let run_fut = tran.run(); - fut_list_tcp.push(async move { - run_fut.await; - }); - redun.add(Box::new(conn)).await?; - } - } - } else { - for s in &self.servers { - if let Transport::Udp = s.transport { - let udp_connect = UdpConnect::new(s.addr); - let tcp_connect = TcpConnect::new(s.addr); - let (conn, tran) = dgram_stream::Connection::new( - udp_connect, - tcp_connect, - ); - // Start the run function on a separate task. - fut_list_udp_tcp.push(async move { - tran.run().await; - }); - redun.add(Box::new(conn)).await?; - } + + for s in &self.servers { + // This assumes that Transport only has UdpTcp and Tcp. Sadly, a + // match doesn’t work here because of the use_cv flag. + if self.options.use_vc || matches!(s.transport, Transport::Tcp) { + let (conn, tran) = + multi_stream::Connection::new(TcpConnect::new(s.addr)); + // Start the run function on a separate task. + fut_list_tcp.push(tran.run()); + redun.add(Box::new(conn)).await?; + } else { + let udp_connect = UdpConnect::new(s.addr); + let tcp_connect = TcpConnect::new(s.addr); + let (conn, tran) = + dgram_stream::Connection::new(udp_connect, tcp_connect); + // Start the run function on a separate task. + fut_list_udp_tcp.push(tran.run()); + redun.add(Box::new(conn)).await?; } } @@ -232,14 +221,14 @@ impl StubResolver { pub async fn lookup_host( &self, - qname: impl ToDname, + qname: impl ToName, ) -> Result, io::Error> { lookup_host(&self, qname).await } pub async fn search_host( &self, - qname: impl ToRelativeDname, + qname: impl ToRelativeName, ) -> Result, io::Error> { search_host(&self, qname).await } @@ -249,8 +238,8 @@ impl StubResolver { /// See the documentation for the [`lookup_srv`] function for details. pub async fn lookup_srv( &self, - service: impl ToRelativeDname, - name: impl ToDname, + service: impl ToRelativeName, + name: impl ToName, fallback_port: u16, ) -> Result, SrvError> { lookup_srv(&self, service, name, fallback_port).await @@ -313,7 +302,7 @@ impl<'a> Resolver for &'a StubResolver { fn query(&self, question: Q) -> Self::Query where - N: ToDname, + N: ToName, Q: Into>, { let message = Query::create_message(question.into()); @@ -390,7 +379,7 @@ impl<'a> Query<'a> { } } - fn create_message(question: Question) -> QueryMessage { + fn create_message(question: Question) -> QueryMessage { let mut message = MessageBuilder::from_target(Default::default()) .expect("MessageBuilder should not fail"); message.header_mut().set_rd(true); diff --git a/src/sign/key.rs b/src/sign/key.rs index 9362ef97c..da9385780 100644 --- a/src/sign/key.rs +++ b/src/sign/key.rs @@ -1,5 +1,5 @@ use crate::base::iana::SecAlg; -use crate::base::name::ToDname; +use crate::base::name::ToName; use crate::rdata::{Dnskey, Ds}; pub trait SigningKey { @@ -8,7 +8,7 @@ pub trait SigningKey { type Error; fn dnskey(&self) -> Result, Self::Error>; - fn ds( + fn ds( &self, owner: N, ) -> Result, Self::Error>; @@ -32,7 +32,7 @@ impl<'a, K: SigningKey> SigningKey for &'a K { fn dnskey(&self) -> Result, Self::Error> { (*self).dnskey() } - fn ds( + fn ds( &self, owner: N, ) -> Result, Self::Error> { diff --git a/src/sign/records.rs b/src/sign/records.rs index 7d0c3a3d5..df2b4128a 100644 --- a/src/sign/records.rs +++ b/src/sign/records.rs @@ -3,7 +3,7 @@ use super::key::SigningKey; use crate::base::cmp::CanonicalOrd; use crate::base::iana::{Class, Rtype}; -use crate::base::name::ToDname; +use crate::base::name::ToName; use crate::base::rdata::{ComposeRecordData, RecordData}; use crate::base::record::Record; use crate::base::Ttl; @@ -30,7 +30,7 @@ impl SortedRecords { pub fn insert(&mut self, record: Record) -> Result<(), Record> where - N: ToDname, + N: ToName, D: RecordData + CanonicalOrd, { let idx = self @@ -55,7 +55,7 @@ impl SortedRecords { pub fn find_soa(&self) -> Option> where - N: ToDname, + N: ToName, D: RecordData, { self.rrsets().find(|rrset| rrset.rtype() == Rtype::SOA) @@ -70,11 +70,11 @@ impl SortedRecords { key: Key, ) -> Result>>, Key::Error> where - N: ToDname + Clone, + N: ToName + Clone, D: RecordData + ComposeRecordData, Key: SigningKey, Octets: From + AsRef<[u8]>, - ApexName: ToDname + Clone, + ApexName: ToName + Clone, { let mut res = Vec::new(); let mut buf = Vec::new(); @@ -168,12 +168,12 @@ impl SortedRecords { ttl: Ttl, ) -> Vec>> where - N: ToDname + Clone, + N: ToName + Clone, D: RecordData, Octets: FromBuilder, Octets::Builder: EmptyBuilder + Truncate + AsRef<[u8]> + AsMut<[u8]>, ::AppendError: fmt::Debug, - ApexName: ToDname, + ApexName: ToName, { let mut res = Vec::new(); @@ -264,7 +264,7 @@ impl Default for SortedRecords { impl From>> for SortedRecords where - N: ToDname, + N: ToName, D: RecordData + CanonicalOrd, { fn from(mut src: Vec>) -> Self { @@ -275,7 +275,7 @@ where impl FromIterator> for SortedRecords where - N: ToDname, + N: ToName, D: RecordData + CanonicalOrd, { fn from_iter>>(iter: T) -> Self { @@ -289,7 +289,7 @@ where impl Extend> for SortedRecords where - N: ToDname, + N: ToName, D: RecordData + CanonicalOrd, { fn extend>>(&mut self, iter: T) { @@ -333,17 +333,17 @@ impl<'a, N, D> Family<'a, N, D> { pub fn is_zone_cut(&self, apex: &FamilyName) -> bool where - N: ToDname, - NN: ToDname, + N: ToName, + NN: ToName, D: RecordData, { self.family_name().ne(apex) && self.records().any(|record| record.rtype() == Rtype::NS) } - pub fn is_in_zone(&self, apex: &FamilyName) -> bool + pub fn is_in_zone(&self, apex: &FamilyName) -> bool where - N: ToDname, + N: ToName, { self.owner().ends_with(&apex.owner) && self.class() == apex.class } @@ -396,7 +396,7 @@ impl FamilyName { key: K, ) -> Result>, K::Error> where - N: ToDname + Clone, + N: ToName + Clone, { key.ds(&self.owner) .map(|ds| self.clone().into_record(ttl, ds)) @@ -412,13 +412,13 @@ impl<'a, N: Clone> FamilyName<&'a N> { } } -impl PartialEq> for FamilyName { +impl PartialEq> for FamilyName { fn eq(&self, other: &FamilyName) -> bool { self.owner.name_eq(&other.owner) && self.class == other.class } } -impl PartialEq> for FamilyName { +impl PartialEq> for FamilyName { fn eq(&self, other: &Record) -> bool { self.owner.name_eq(other.owner()) && self.class == other.class() } @@ -484,9 +484,9 @@ impl<'a, N, D> RecordsIter<'a, N, D> { self.slice[0].owner() } - pub fn skip_before(&mut self, apex: &FamilyName) + pub fn skip_before(&mut self, apex: &FamilyName) where - N: ToDname, + N: ToName, { while let Some(first) = self.slice.first() { if apex == first { @@ -499,7 +499,7 @@ impl<'a, N, D> RecordsIter<'a, N, D> { impl<'a, N, D> Iterator for RecordsIter<'a, N, D> where - N: ToDname + 'a, + N: ToName + 'a, D: RecordData + 'a, { type Item = Family<'a, N, D>; @@ -539,7 +539,7 @@ impl<'a, N, D> RrsetIter<'a, N, D> { impl<'a, N, D> Iterator for RrsetIter<'a, N, D> where - N: ToDname + 'a, + N: ToName + 'a, D: RecordData + 'a, { type Item = Rrset<'a, N, D>; @@ -580,7 +580,7 @@ impl<'a, N, D> FamilyIter<'a, N, D> { impl<'a, N, D> Iterator for FamilyIter<'a, N, D> where - N: ToDname + 'a, + N: ToName + 'a, D: RecordData + 'a, { type Item = Rrset<'a, N, D>; diff --git a/src/sign/ring.rs b/src/sign/ring.rs index d70530e2e..bf4614f2b 100644 --- a/src/sign/ring.rs +++ b/src/sign/ring.rs @@ -4,7 +4,7 @@ use super::key::SigningKey; use crate::base::iana::{DigestAlg, SecAlg}; -use crate::base::name::ToDname; +use crate::base::name::ToName; use crate::base::rdata::ComposeRecordData; use crate::rdata::{Dnskey, Ds}; #[cfg(feature = "bytes")] @@ -70,7 +70,7 @@ impl<'a> SigningKey for Key<'a> { Ok(self.dnskey.clone()) } - fn ds( + fn ds( &self, owner: N, ) -> Result, Self::Error> { diff --git a/src/tsig/interop.rs b/src/tsig/interop.rs index 82b5c2def..f8055e252 100644 --- a/src/tsig/interop.rs +++ b/src/tsig/interop.rs @@ -6,7 +6,7 @@ use crate::base::message::Message; use crate::base::message_builder::{ AdditionalBuilder, AnswerBuilder, MessageBuilder, StreamTarget, }; -use crate::base::name::Dname; +use crate::base::name::Name; use crate::base::record::Ttl; use crate::rdata::tsig::Time48; use crate::rdata::{Soa, A}; @@ -48,7 +48,7 @@ fn tsig_client_nsd() { let (key, secret) = tsig::Key::generate( tsig::Algorithm::Sha1, &rng, - Dname::from_str("test.key.").unwrap(), + Name::from_str("test.key.").unwrap(), None, None, ) @@ -84,7 +84,7 @@ fn tsig_client_nsd() { // Create an AXFR request and send it to NSD. let request = TestBuilder::new_stream_vec(); let mut request = request - .request_axfr(Dname::>::from_str("example.com.").unwrap()) + .request_axfr(Name::>::from_str("example.com.").unwrap()) .unwrap() .additional(); let tran = tsig::ClientTransaction::request( @@ -127,7 +127,7 @@ fn tsig_server_drill() { let (key, secret) = tsig::Key::generate( tsig::Algorithm::Sha1, &rng, - Dname::from_str("test.key.").unwrap(), + Name::from_str("test.key.").unwrap(), None, None, ) @@ -199,7 +199,7 @@ fn tsig_client_sequence_nsd() { let (key, secret) = tsig::Key::generate( tsig::Algorithm::Sha1, &rng, - Dname::from_str("test.key.").unwrap(), + Name::from_str("test.key.").unwrap(), None, None, ) @@ -235,7 +235,7 @@ fn tsig_client_sequence_nsd() { let mut sock = TcpStream::connect("127.0.0.1:54323").unwrap(); let request = TestBuilder::new_stream_vec(); let mut request = request - .request_axfr(Dname::>::from_str("example.com.").unwrap()) + .request_axfr(Name::>::from_str("example.com.").unwrap()) .unwrap() .additional(); let mut tran = @@ -277,7 +277,7 @@ fn tsig_server_sequence_drill() { let (key, secret) = tsig::Key::generate( tsig::Algorithm::Sha1, &rng, - Dname::from_str("test.key.").unwrap(), + Name::from_str("test.key.").unwrap(), None, None, ) @@ -376,11 +376,11 @@ fn make_last_axfr(request: &TestMessage) -> TestAdditional { fn push_soa(builder: &mut TestAnswer) { builder .push(( - Dname::>::from_str("example.com.").unwrap(), + Name::>::from_str("example.com.").unwrap(), 3600, Soa::new( - Dname::>::from_str("mname.example.com.").unwrap(), - Dname::>::from_str("rname.example.com.").unwrap(), + Name::>::from_str("mname.example.com.").unwrap(), + Name::>::from_str("rname.example.com.").unwrap(), 12.into(), Ttl::from_secs(3600), Ttl::from_secs(3600), @@ -394,7 +394,7 @@ fn push_soa(builder: &mut TestAnswer) { fn push_a(builder: &mut TestAnswer, zero: u8, one: u8, two: u8) { builder .push(( - Dname::>::from_str("example.com.").unwrap(), + Name::>::from_str("example.com.").unwrap(), 3600, A::from_octets(10, zero, one, two), )) diff --git a/src/tsig/mod.rs b/src/tsig/mod.rs index 9dbd79dfb..a8fd9d525 100644 --- a/src/tsig/mod.rs +++ b/src/tsig/mod.rs @@ -62,7 +62,7 @@ use crate::base::message::Message; use crate::base::message_builder::{ AdditionalBuilder, MessageBuilder, PushError, }; -use crate::base::name::{Dname, Label, ParsedDname, ToDname, ToLabelIter}; +use crate::base::name::{Label, Name, ParsedName, ToLabelIter, ToName}; use crate::base::record::Record; use crate::base::wire::{Composer, ParseError}; use crate::rdata::tsig::{Time48, Tsig}; @@ -75,7 +75,7 @@ use std::collections::HashMap; //------------ KeyName ------------------------------------------------------- -pub type KeyName = Dname>; +pub type KeyName = Name>; //------------ Key ----------------------------------------------------------- @@ -267,7 +267,7 @@ impl Key { tsig: &MessageTsig, ) -> Result<(), ValidationError> { if *tsig.record.owner() != self.name - || *tsig.record.data().algorithm() != self.algorithm().to_dname() + || *tsig.record.data().algorithm() != self.algorithm().to_name() { Err(ValidationError::BadKey) } else { @@ -352,7 +352,7 @@ pub trait KeyStore { /// /// The method looks up a key based on a pair of name and algorithm. If /// the key can be found, it is returned. Otherwise, `None` is returned. - fn get_key( + fn get_key( &self, name: &N, algorithm: Algorithm, @@ -362,7 +362,7 @@ pub trait KeyStore { impl + Clone> KeyStore for K { type Key = Self; - fn get_key( + fn get_key( &self, name: &N, algorithm: Algorithm, @@ -385,13 +385,13 @@ where { type Key = K; - fn get_key( + fn get_key( &self, name: &N, algorithm: Algorithm, ) -> Option { // XXX This seems a bit wasteful. - let name = name.try_to_dname().ok()?; + let name = name.try_to_name().ok()?; self.get(&(name, algorithm)).cloned() } } @@ -1011,7 +1011,7 @@ impl> SigningContext { // 4.5.1. KEY check and error handling let algorithm = - match Algorithm::from_dname(tsig.record.data().algorithm()) { + match Algorithm::from_name(tsig.record.data().algorithm()) { Some(algorithm) => algorithm, None => return Err(ServerError::unsigned(TsigRcode::BADKEY)), }; @@ -1297,8 +1297,8 @@ struct MessageTsig<'a, Octs: Octets + ?Sized + 'a> { /// The actual record. #[allow(clippy::type_complexity)] record: Record< - ParsedDname>, - Tsig, ParsedDname>>, + ParsedName>, + Tsig, ParsedName>>, >, /// The index of the start of the record. @@ -1405,7 +1405,7 @@ impl Variables { // that the hmac is unreasonable large. Since we control its // creation, panicing in this case is fine. Tsig::new( - key.algorithm().to_dname(), + key.algorithm().to_name(), self.time_signed, self.fudge, hmac, @@ -1474,7 +1474,7 @@ impl Algorithm { /// Creates a value from its domain name representation. /// /// Returns `None` if the name doesn’t represent a known algorithm. - pub fn from_dname(name: &N) -> Option { + pub fn from_name(name: &N) -> Option { let mut labels = name.iter_labels(); let first = match labels.next() { Some(label) => label, @@ -1531,8 +1531,8 @@ impl Algorithm { } /// Returns a domain name for this value. - pub fn to_dname(self) -> Dname<&'static [u8]> { - unsafe { Dname::from_octets_unchecked(self.into_wire_slice()) } + pub fn to_name(self) -> Name<&'static [u8]> { + unsafe { Name::from_octets_unchecked(self.into_wire_slice()) } } /// Returns the native length of a signature created with this algorithm. diff --git a/src/validate.rs b/src/validate.rs index eb883178e..b6a82b64a 100644 --- a/src/validate.rs +++ b/src/validate.rs @@ -6,7 +6,7 @@ use crate::base::cmp::CanonicalOrd; use crate::base::iana::{DigestAlg, SecAlg}; -use crate::base::name::ToDname; +use crate::base::name::ToName; use crate::base::rdata::{ComposeRecordData, RecordData}; use crate::base::record::Record; use crate::base::wire::{Compose, Composer}; @@ -38,9 +38,9 @@ pub trait DnskeyExt { /// ``` /// /// [RFC 4034, Section 5.1.4]: https://tools.ietf.org/html/rfc4034#section-5.1.4 - fn digest( + fn digest( &self, - dname: &N, + name: &N, algorithm: DigestAlg, ) -> Result; } @@ -67,14 +67,14 @@ where /// ``` /// /// [RFC 4034, Section 5.1.4]: https://tools.ietf.org/html/rfc4034#section-5.1.4 - fn digest( + fn digest( &self, - dname: &N, + name: &N, algorithm: DigestAlg, ) -> Result { let mut buf: Vec = Vec::new(); with_infallible(|| { - dname.compose_canonical(&mut buf)?; + name.compose_canonical(&mut buf)?; self.compose_canonical_rdata(&mut buf) }); @@ -109,13 +109,13 @@ pub trait RrsigExt { /// the received RRset due to DNS name compression, decremented TTLs, or /// wildcard expansion. /// ``` - fn signed_data( + fn signed_data( &self, buf: &mut B, records: &mut [impl AsRef>], ) -> Result<(), B::AppendError> where - D: CanonicalOrd + ComposeRecordData + Sized; + D: RecordData + CanonicalOrd + ComposeRecordData + Sized; /// Attempt to use the cryptographic signature to authenticate the signed data, and thus authenticate the RRSET. /// The signed data is expected to be calculated as per [RFC4035, Section 5.3.2](https://tools.ietf.org/html/rfc4035#section-5.3.2). @@ -145,14 +145,14 @@ pub trait RrsigExt { ) -> Result<(), AlgorithmError>; } -impl, Name: ToDname> RrsigExt for Rrsig { - fn signed_data( +impl, Name: ToName> RrsigExt for Rrsig { + fn signed_data( &self, buf: &mut B, records: &mut [impl AsRef>], ) -> Result<(), B::AppendError> where - D: CanonicalOrd + ComposeRecordData + Sized, + D: RecordData + CanonicalOrd + ComposeRecordData + Sized, { // signed_data = RRSIG_RDATA | RR(1) | RR(2)... where // "|" denotes concatenation @@ -356,10 +356,10 @@ mod test { use bytes::Bytes; use std::str::FromStr; - type Dname = crate::base::name::Dname>; + type Name = crate::base::name::Name>; type Ds = crate::rdata::Ds>; type Dnskey = crate::rdata::Dnskey>; - type Rrsig = crate::rdata::Rrsig, Dname>; + type Rrsig = crate::rdata::Rrsig, Name>; // Returns current root KSK/ZSK for testing (2048b) fn root_pubkey() -> (Dnskey, Dnskey) { @@ -408,7 +408,7 @@ mod test { #[test] fn dnskey_digest() { let (dnskey, _) = root_pubkey(); - let owner = Dname::root(); + let owner = Name::root(); let expected = Ds::new( 20326, SecAlg::RSASHA256, @@ -428,7 +428,7 @@ mod test { #[test] fn dnskey_digest_unsupported() { let (dnskey, _) = root_pubkey(); - let owner = Dname::root(); + let owner = Name::root(); assert!(dnskey.digest(&owner, DigestAlg::GOST).is_err()); } @@ -472,7 +472,7 @@ mod test { 1560211200.into(), 1558396800.into(), 20326, - Dname::root(), + Name::root(), base64::decode::>( "otBkINZAQu7AvPKjr/xWIEE7+SoZtKgF8bzVynX6bfJMJuPay8jPvNmwXkZOdSoYlvFp0bk9JWJKCh8y5uoNfMFkN6OSrDkr3t0E+c8c0Mnmwkk5CETH3Gqxthi0yyRX5T4VlHU06/Ks4zI+XAgl3FBpOc554ivdzez8YCjAIGx7XgzzooEb7heMSlLc7S7/HNjw51TPRs4RxrAVcezieKCzPPpeWBhjE6R3oiSwrl0SBD4/yplrDlr7UHs/Atcm3MSgemdyr2sOoOUkVQCVpcj3SQQezoD2tCM7861CXEQdg5fjeHDtz285xHt5HJpA5cOcctRo4ihybfow/+V7AQ==", ) @@ -490,7 +490,7 @@ mod test { Timestamp::from_str("20210921162830").unwrap(), Timestamp::from_str("20210906162330").unwrap(), 35886, - "net.".parse::().unwrap(), + "net.".parse::().unwrap(), base64::decode::>( "j1s1IPMoZd0mbmelNVvcbYNe2tFCdLsLpNCnQ8xW6d91ujwPZ2yDlc3lU3hb+Jq3sPoj+5lVgB7fZzXQUQTPFWLF7zvW49da8pWuqzxFtg6EjXRBIWH5rpEhOcr+y3QolJcPOTx+/utCqt2tBKUUy3LfM6WgvopdSGaryWdwFJPW7qKHjyyLYxIGx5AEuLfzsA5XZf8CmpUheSRH99GRZoIB+sQzHuelWGMQ5A42DPvOVZFmTpIwiT2QaIpid4nJ7jNfahfwFrCoS+hvqjK9vktc5/6E/Mt7DwCQDaPt5cqDfYltUitQy+YA5YP5sOhINChYadZe+2N80OA+RKz0mA==", ) @@ -538,7 +538,7 @@ mod test { .unwrap(), ); - let owner = Dname::from_str("cloudflare.com.").unwrap(); + let owner = Name::from_str("cloudflare.com.").unwrap(); let rrsig = Rrsig::new( Rtype::DNSKEY, SecAlg::ECDSAP256SHA256, @@ -584,7 +584,7 @@ mod test { ); let owner = - Dname::from_octets(Vec::from(b"\x07ED25519\x02nl\x00".as_ref())) + Name::from_octets(Vec::from(b"\x07ED25519\x02nl\x00".as_ref())) .unwrap(); let rrsig = Rrsig::new( Rtype::DNSKEY, @@ -616,7 +616,7 @@ mod test { 1560211200.into(), 1558396800.into(), 20326, - Dname::root(), + Name::root(), base64::decode::>( "otBkINZAQu7AvPKjr/xWIEE7+SoZtKgF8bzVynX6bfJMJuPay8jPvNmwXkZ\ OdSoYlvFp0bk9JWJKCh8y5uoNfMFkN6OSrDkr3t0E+c8c0Mnmwkk5CETH3Gq\ @@ -629,7 +629,7 @@ mod test { ) .unwrap(); - let mut records: Vec, Dname>>> = + let mut records: Vec, Name>>> = [&ksk, &zsk] .iter() .cloned() @@ -676,7 +676,7 @@ mod test { Timestamp::from_str("20040509183619").unwrap(), Timestamp::from_str("20040409183619").unwrap(), 38519, - Dname::from_str("example.").unwrap(), + Name::from_str("example.").unwrap(), base64::decode::>( "OMK8rAZlepfzLWW75Dxd63jy2wswESzxDKG2f9AMN1CytCd10cYI\ SAxfAdvXSZ7xujKAtPbctvOQ2ofO7AZJ+d01EeeQTVBPq4/6KCWhq\ @@ -687,10 +687,10 @@ mod test { ) .unwrap(); let record = Record::new( - Dname::from_str("a.z.w.example.").unwrap(), + Name::from_str("a.z.w.example.").unwrap(), Class::IN, Ttl::from_secs(3600), - Mx::new(1, Dname::from_str("ai.example.").unwrap()), + Mx::new(1, Name::from_str("ai.example.").unwrap()), ); let signed_data = { let mut buf = Vec::new(); diff --git a/src/zonefile/inplace.rs b/src/zonefile/inplace.rs index e1464ceab..81dc40ed4 100644 --- a/src/zonefile/inplace.rs +++ b/src/zonefile/inplace.rs @@ -21,7 +21,7 @@ use octseq::str::Str; use crate::base::charstr::CharStr; use crate::base::iana::{Class, Rtype}; -use crate::base::name::{Chain, Dname, RelativeDname, ToDname}; +use crate::base::name::{Chain, Name, RelativeName, ToName}; use crate::base::record::Record; use crate::base::scan::{ BadSymbol, ConvertSymbols, EntrySymbol, Scan, Scanner, ScannerError, @@ -33,7 +33,7 @@ use crate::rdata::ZoneRecordData; //------------ Type Aliases -------------------------------------------------- /// The type used for scanned domain names. -pub type ScannedDname = Chain, Dname>; +pub type ScannedDname = Chain, Name>; /// The type used for scanned record data. pub type ScannedRecordData = ZoneRecordData; @@ -64,7 +64,7 @@ pub struct Zonefile { buf: SourceBuf, /// The current origin. - origin: Option>, + origin: Option>, /// The last owner. last_owner: Option, @@ -167,7 +167,7 @@ impl Zonefile { /// is not provided via this function or via an $ORIGIN directive, then /// any relative names encountered will cause iteration to terminate with /// a missing origin error. - pub fn set_origin(&mut self, origin: Dname) { + pub fn set_origin(&mut self, origin: Name) { self.origin = Some(origin) } @@ -192,7 +192,7 @@ impl Zonefile { } /// Returns the origin name of the zonefile. - pub fn origin(&self) -> Result, EntryError> { + pub fn origin(&self) -> Result, EntryError> { self.origin .as_ref() .cloned() @@ -226,7 +226,7 @@ pub enum Entry { path: ScannedString, /// The initial origin name of the included file, if provided. - origin: Option>, + origin: Option>, }, } @@ -243,7 +243,7 @@ enum ScannedEntry { Entry(Entry), /// An `$ORIGIN` directive changing the origin name. - Origin(Dname), + Origin(Name), /// A `$TTL` directive changing the default TTL if it isn’t given. Ttl(Ttl), @@ -322,7 +322,7 @@ impl<'a> EntryScanner<'a> { /// Scans a regular record with an owner name of `@`. fn scan_at_record(&mut self) -> Result { - let owner = RelativeDname::empty_bytes() + let owner = RelativeName::empty_bytes() .chain(match self.zonefile.origin.as_ref().cloned() { Some(origin) => origin, None => return Err(EntryError::missing_origin()), @@ -456,13 +456,13 @@ impl<'a> EntryScanner<'a> { fn scan_control(&mut self) -> Result { let ctrl = self.scan_string()?; if ctrl.eq_ignore_ascii_case("$ORIGIN") { - let origin = self.scan_dname()?.to_dname(); + let origin = self.scan_name()?.to_name(); self.zonefile.buf.require_line_feed()?; Ok(ScannedEntry::Origin(origin)) } else if ctrl.eq_ignore_ascii_case("$INCLUDE") { let path = self.scan_string()?; let origin = if !self.zonefile.buf.is_line_feed() { - Some(self.scan_dname()?.to_dname()) + Some(self.scan_name()?.to_name()) } else { None }; @@ -481,7 +481,7 @@ impl<'a> EntryScanner<'a> { impl<'a> Scanner for EntryScanner<'a> { type Octets = Bytes; type OctetsBuilder = BytesMut; - type Dname = ScannedDname; + type Name = ScannedDname; type Error = EntryError; fn has_space(&self) -> bool { @@ -613,7 +613,7 @@ impl<'a> Scanner for EntryScanner<'a> { Ok(res) } - fn scan_dname(&mut self) -> Result { + fn scan_name(&mut self) -> Result { // Because the labels in a domain name have their content preceeded // by the length octet, an unescaped domain name can be almost as is // if we have one extra octet to the left. Luckily, we always do @@ -639,16 +639,16 @@ impl<'a> Scanner for EntryScanner<'a> { // have an empty domain name which is just the origin. self.zonefile.buf.next_item()?; if start == 0 { - return RelativeDname::empty_bytes() + return RelativeName::empty_bytes() .chain(self.zonefile.origin()?) - .map_err(|_| EntryError::bad_dname()); + .map_err(|_| EntryError::bad_name()); } else { return unsafe { - RelativeDname::from_octets_unchecked( + RelativeName::from_octets_unchecked( self.zonefile.buf.split_to(write).freeze(), ) - .chain(Dname::root()) - .map_err(|_| EntryError::bad_dname()) + .chain(Name::root()) + .map_err(|_| EntryError::bad_name()) }; } } @@ -659,28 +659,28 @@ impl<'a> Scanner for EntryScanner<'a> { // continue to the next label. if write == 1 { if self.zonefile.buf.next_symbol()?.is_some() { - return Err(EntryError::bad_dname()); + return Err(EntryError::bad_name()); } else { self.zonefile.buf.next_item()?; - return Ok(RelativeDname::empty() - .chain(Dname::root()) + return Ok(RelativeName::empty() + .chain(Name::root()) .expect("failed to make root name")); } } if write > 254 { - return Err(EntryError::bad_dname()); + return Err(EntryError::bad_name()); } } Some(false) => { // Reached end of token. This means we have a relative - // dname. + // name. self.zonefile.buf.next_item()?; return unsafe { - RelativeDname::from_octets_unchecked( + RelativeName::from_octets_unchecked( self.zonefile.buf.split_to(write).freeze(), ) .chain(self.zonefile.origin()?) - .map_err(|_| EntryError::bad_dname()) + .map_err(|_| EntryError::bad_name()) }; } } @@ -850,7 +850,7 @@ impl<'a> EntryScanner<'a> { // A char symbol. Just increase the write index. *write += 1; if *write >= latest { - return Err(EntryError::bad_dname()); + return Err(EntryError::bad_name()); } } None => { @@ -891,7 +891,7 @@ impl<'a> EntryScanner<'a> { self.zonefile.buf.buf[*write] = sym.into_octet()?; *write += 1; if *write >= latest { - return Err(EntryError::bad_dname()); + return Err(EntryError::bad_name()); } } } @@ -1435,8 +1435,8 @@ impl EntryError { EntryError("bad charstr") } - fn bad_dname() -> Self { - EntryError("bad dname") + fn bad_name() -> Self { + EntryError("bad name") } fn unbalanced_parens() -> Self { @@ -1523,7 +1523,7 @@ impl std::error::Error for Error {} #[cfg(feature = "std")] mod test { use super::*; - use crate::base::ParsedDname; + use crate::base::ParsedName; use octseq::Parser; use std::vec::Vec; @@ -1564,10 +1564,9 @@ mod test { #[derive(serde::Deserialize)] #[allow(clippy::type_complexity)] struct TestCase { - origin: Dname, + origin: Name, zonefile: std::string::String, - result: - Vec, ZoneRecordData>>>, + result: Vec, ZoneRecordData>>>, } impl TestCase { @@ -1590,8 +1589,8 @@ mod test { let mut parser = Parser::from_ref(&buf); let parsed = Record::< - ParsedDname, - ZoneRecordData>, + ParsedName, + ZoneRecordData>, >::parse(&mut parser) .unwrap() .unwrap(); diff --git a/src/zonefile/mod.rs b/src/zonefile/mod.rs index 6cf90f22b..c208be2af 100644 --- a/src/zonefile/mod.rs +++ b/src/zonefile/mod.rs @@ -2,7 +2,4 @@ #![cfg(feature = "zonefile")] #![cfg_attr(docsrs, doc(cfg(feature = "zonefile")))] -pub mod error; pub mod inplace; -#[cfg(feature = "unstable-zonetree")] -pub mod parsed; diff --git a/src/zonetree/answer.rs b/src/zonetree/answer.rs index 1fa34afd8..12ae7cc8b 100644 --- a/src/zonetree/answer.rs +++ b/src/zonetree/answer.rs @@ -1,14 +1,17 @@ //! Answers to zone tree queries. -//------------ Answer -------------------------------------------------------- +use octseq::Octets; -use super::{SharedRr, SharedRrset, StoredDname}; use crate::base::iana::Rcode; use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; use crate::base::Message; use crate::base::MessageBuilder; -use octseq::Octets; + +use super::types::StoredDname; +use super::{SharedRr, SharedRrset}; + +//------------ Answer -------------------------------------------------------- /// A DNS answer to a query against a [`Zone`]. /// diff --git a/src/zonefile/error.rs b/src/zonetree/error.rs similarity index 63% rename from src/zonefile/error.rs rename to src/zonetree/error.rs index eed516f63..917629347 100644 --- a/src/zonefile/error.rs +++ b/src/zonetree/error.rs @@ -1,19 +1,23 @@ //! Zone related errors. -//------------ ZoneCutError -------------------------------------------------- - use std::fmt::Display; +use std::io; use std::vec::Vec; use crate::base::Rtype; -use crate::zonetree::{StoredDname, StoredRecord}; +use crate::zonefile::inplace; -use super::inplace; +use super::types::{StoredDname, StoredRecord}; + +//------------ ZoneCutError -------------------------------------------------- /// A zone cut is not valid with respect to the zone's apex. #[derive(Clone, Copy, Debug)] pub enum ZoneCutError { + /// A zone cut cannot exist outside of the zone. OutOfZone, + + /// A zone cut cannot exist at the apex of a zone. ZoneCutAtApex, } @@ -37,7 +41,10 @@ impl Display for ZoneCutError { /// A CNAME is not valid with respect to the zone's apex. #[derive(Clone, Copy, Debug)] pub enum CnameError { + /// A CNAME cannot exist outside of the zone. OutOfZone, + + /// A CNAME cannot exist at the apex of a zone. CnameAtApex, } @@ -64,6 +71,7 @@ pub struct OutOfZone; //------------ RecordError --------------------------------------------------- +/// A zone file record is invalid. #[derive(Clone, Debug)] pub enum RecordError { /// The class of the record does not match the class of the zone. @@ -125,15 +133,20 @@ impl Display for RecordError { /// A set of problems relating to a zone. #[derive(Clone, Debug, Default)] pub struct ZoneErrors { - errors: Vec<(StoredDname, OwnerError)>, + errors: Vec<(StoredDname, ContextError)>, } impl ZoneErrors { - pub fn add_error(&mut self, name: StoredDname, error: OwnerError) { + /// Add an error to the set. + pub fn add_error(&mut self, name: StoredDname, error: ContextError) { self.errors.push((name, error)) } - pub fn into_result(self) -> Result<(), Self> { + /// Unwrap the set of errors. + /// + /// Returns the set of errors as [Result::Err(ZonErrors)] or [Result::Ok] + /// if the set is empty. + pub fn unwrap(self) -> Result<(), Self> { if self.errors.is_empty() { Ok(()) } else { @@ -152,10 +165,11 @@ impl Display for ZoneErrors { } } -//------------ OwnerError --------------------------------------------------- +//------------ ContextError -------------------------------------------------- +/// A zone file record is not correct for its context. #[derive(Clone, Debug)] -pub enum OwnerError { +pub enum ContextError { /// A NS RRset is missing at a zone cut. /// /// (This happens if there is only a DS RRset.) @@ -171,17 +185,70 @@ pub enum OwnerError { OutOfZone(Rtype), } -impl Display for OwnerError { +impl Display for ContextError { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match self { - OwnerError::MissingNs => write!(f, "Missing NS"), - OwnerError::InvalidZonecut(err) => { + ContextError::MissingNs => write!(f, "Missing NS"), + ContextError::InvalidZonecut(err) => { write!(f, "Invalid zone cut: {err}") } - OwnerError::InvalidCname(err) => { + ContextError::InvalidCname(err) => { write!(f, "Invalid CNAME: {err}") } - OwnerError::OutOfZone(err) => write!(f, "Out of zone: {err}"), + ContextError::OutOfZone(err) => write!(f, "Out of zone: {err}"), + } + } +} + +//------------ ZoneTreeModificationError ------------------------------------- + +/// An attempt to modify a [`ZoneTree`] failed. +/// +/// [`ZoneTree`]: crate::zonetree::ZoneTree +#[derive(Debug)] +pub enum ZoneTreeModificationError { + /// The specified zone already exists. + ZoneExists, + + /// The specified zone does not exist. + ZoneDoesNotExist, + + /// The operation failed due to an I/O error. + Io(io::Error), +} + +impl From for ZoneTreeModificationError { + fn from(src: io::Error) -> Self { + ZoneTreeModificationError::Io(src) + } +} + +impl From for io::Error { + fn from(src: ZoneTreeModificationError) -> Self { + match src { + ZoneTreeModificationError::Io(err) => err, + ZoneTreeModificationError::ZoneDoesNotExist => { + io::Error::new(io::ErrorKind::Other, "zone does not exist") + } + ZoneTreeModificationError::ZoneExists => { + io::Error::new(io::ErrorKind::Other, "zone exists") + } + } + } +} + +impl Display for ZoneTreeModificationError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + ZoneTreeModificationError::ZoneExists => { + write!(f, "Zone already exists") + } + ZoneTreeModificationError::ZoneDoesNotExist => { + write!(f, "Zone does not exist") + } + ZoneTreeModificationError::Io(err) => { + write!(f, "Io error: {err}") + } } } } diff --git a/src/zonetree/in_memory/builder.rs b/src/zonetree/in_memory/builder.rs index 8e17363ea..38544129e 100644 --- a/src/zonetree/in_memory/builder.rs +++ b/src/zonetree/in_memory/builder.rs @@ -4,12 +4,10 @@ use std::sync::Arc; use std::vec::Vec; use crate::base::iana::Class; -use crate::base::name::{Label, ToDname}; -use crate::zonefile::error::{CnameError, OutOfZone, ZoneCutError}; -use crate::zonetree::types::ZoneCut; -use crate::zonetree::{ - SharedRr, SharedRrset, StoredDname, StoredRecord, Zone, -}; +use crate::base::name::{Label, ToName}; +use crate::zonetree::error::{CnameError, OutOfZone, ZoneCutError}; +use crate::zonetree::types::{StoredDname, StoredRecord, ZoneCut}; +use crate::zonetree::{SharedRr, SharedRrset, Zone}; use super::nodes::{Special, ZoneApex, ZoneNode}; use super::versioned::Version; @@ -53,7 +51,7 @@ use super::versioned::Version; /// /// [module docs]: crate::zonetree /// [`inplace::Zonefile`]: crate::zonefile::inplace::Zonefile -/// [`parsed::Zonefile`]: crate::zonefile::parsed::Zonefile +/// [`parsed::Zonefile`]: crate::zonetree::parsed::Zonefile /// [presentation format]: /// https://datatracker.ietf.org/doc/html/rfc9499#section-2-1.16.1.6.1.3 /// [`ReadableZone::query`]: crate::zonetree::ReadableZone::query() @@ -88,7 +86,7 @@ impl ZoneBuilder { /// Inserts a [`SharedRrset`] for the given owner name. pub fn insert_rrset( &mut self, - name: &impl ToDname, + name: &impl ToName, rrset: SharedRrset, ) -> Result<(), OutOfZone> { match self.get_node(self.apex.prepare_name(name)?) { @@ -119,7 +117,7 @@ impl ZoneBuilder { /// https://datatracker.ietf.org/doc/html/rfc4033#section-2 pub fn insert_zone_cut( &mut self, - name: &impl ToDname, + name: &impl ToName, ns: SharedRrset, ds: Option, glue: Vec, @@ -142,7 +140,7 @@ impl ZoneBuilder { /// [`Cname`]: crate::rdata::rfc1035::Cname pub fn insert_cname( &mut self, - name: &impl ToDname, + name: &impl ToName, cname: SharedRr, ) -> Result<(), CnameError> { let node = self.get_node(self.apex.prepare_name(name)?)?; diff --git a/src/zonetree/in_memory/nodes.rs b/src/zonetree/in_memory/nodes.rs index e7d789589..433296ce8 100644 --- a/src/zonetree/in_memory/nodes.rs +++ b/src/zonetree/in_memory/nodes.rs @@ -12,12 +12,12 @@ use parking_lot::{ use tokio::sync::Mutex; use crate::base::iana::{Class, Rtype}; -use crate::base::name::{Label, OwnedLabel, ToDname, ToLabelIter}; -use crate::zonefile::error::{CnameError, OutOfZone, ZoneCutError}; -use crate::zonetree::types::ZoneCut; +use crate::base::name::{Label, OwnedLabel, ToLabelIter, ToName}; +use crate::zonetree::error::{CnameError, OutOfZone, ZoneCutError}; +use crate::zonetree::types::{StoredDname, ZoneCut}; use crate::zonetree::walk::WalkState; use crate::zonetree::{ - ReadableZone, SharedRr, SharedRrset, StoredDname, WritableZone, ZoneStore, + ReadableZone, SharedRr, SharedRrset, WritableZone, ZoneStore, }; use super::read::ReadZone; @@ -69,7 +69,7 @@ impl ZoneApex { pub fn prepare_name<'l>( &self, - qname: &'l impl ToDname, + qname: &'l impl ToName, ) -> Result + Clone, OutOfZone> { let mut qname = qname.iter_labels().rev(); for apex_label in self.name().iter_labels().rev() { diff --git a/src/zonetree/in_memory/read.rs b/src/zonetree/in_memory/read.rs index afd480486..de71e11d2 100644 --- a/src/zonetree/in_memory/read.rs +++ b/src/zonetree/in_memory/read.rs @@ -7,8 +7,7 @@ use bytes::Bytes; use crate::base::iana::{Rcode, Rtype}; use crate::base::name::Label; -use crate::base::Dname; -use crate::zonefile::error::OutOfZone; +use crate::base::Name; use crate::zonetree::answer::{Answer, AnswerAuthority}; use crate::zonetree::types::ZoneCut; use crate::zonetree::walk::WalkState; @@ -17,6 +16,7 @@ use crate::zonetree::{ReadableZone, Rrset, SharedRr, SharedRrset, WalkOp}; use super::nodes::{NodeChildren, NodeRrsets, Special, ZoneApex, ZoneNode}; use super::versioned::Version; use super::versioned::VersionMarker; +use crate::zonetree::error::OutOfZone; //------------ ReadZone ------------------------------------------------------ @@ -269,7 +269,7 @@ impl ReadableZone for ReadZone { fn query( &self, - qname: Dname, + qname: Name, qtype: Rtype, ) -> Result { let mut qname = self.apex.prepare_name(&qname)?; diff --git a/src/zonetree/mod.rs b/src/zonetree/mod.rs index efadcee74..c8bb11c79 100644 --- a/src/zonetree/mod.rs +++ b/src/zonetree/mod.rs @@ -8,7 +8,7 @@ //! //! Individual `Zone`s within the tree can be looked up by containing or exact //! name, and then one can [`query`] the found `Zone` by [`Class`], [`Rtype`] and -//! [`Dname`] to produce an [`Answer`], which in turn can be used to produce a +//! [`Name`] to produce an [`Answer`], which in turn can be used to produce a //! response [`Message`] for serving to a DNS client. //! //! Trees can also be iterated over to inspect or export their content. @@ -36,8 +36,9 @@ //! //! ``` //! use domain::base::iana::{Class, Rcode, Rtype}; -//! use domain::base::name::Dname; -//! use domain::zonefile::{inplace, parsed}; +//! use domain::base::name::Name; +//! use domain::zonefile::inplace; +//! use domain::zonetree::parsed; //! use domain::zonetree::{Answer, Zone, ZoneBuilder, ZoneTree}; //! //! // Prepare some zone file bytes to demonstrate with. @@ -62,7 +63,7 @@ //! tree.insert_zone(zone).unwrap(); //! //! // Query the zone tree. -//! let qname = Dname::bytes_from_str("example.com").unwrap(); +//! let qname = Name::bytes_from_str("example.com").unwrap(); //! let qtype = Rtype::A; //! let found_zone = tree.find_zone(&qname, Class::IN).unwrap(); //! let res: Answer = found_zone.read().query(qname, qtype).unwrap(); @@ -74,17 +75,19 @@ //! [`query`]: crate::zonetree::ReadableZone::query //! [`Class`]: crate::base::iana::Class //! [`Rtype`]: crate::base::iana::Rtype -//! [`Dname`]: crate::base::name::Dname +//! [`Name`]: crate::base::name::Name //! [`Message`]: crate::base::Message //! [`NoError`]: crate::base::iana::code::Rcode::NOERROR //! [`NxDomain`]: crate::base::iana::code::Rcode::NXDOMAIN //! [`ZoneBuilder`]: in_memory::ZoneBuilder mod answer; +pub mod error; mod in_memory; +pub mod parsed; mod traits; mod tree; -mod types; +pub mod types; mod walk; mod zone; @@ -94,8 +97,6 @@ pub use self::traits::{ ReadableZone, WritableZone, WritableZoneNode, ZoneStore, }; pub use self::tree::ZoneTree; -pub use self::types::{ - Rrset, SharedRr, SharedRrset, StoredDname, StoredRecord, -}; +pub use self::types::{Rrset, SharedRr, SharedRrset}; pub use self::walk::WalkOp; pub use self::zone::Zone; diff --git a/src/zonefile/parsed.rs b/src/zonetree/parsed.rs similarity index 83% rename from src/zonefile/parsed.rs rename to src/zonetree/parsed.rs index a5ab05a8b..a153c1c6a 100644 --- a/src/zonefile/parsed.rs +++ b/src/zonetree/parsed.rs @@ -1,23 +1,23 @@ -//! Importing from and (in future) exporting to a zonefiles. +//! Importing from and (in future) exporting to a zone files. use std::collections::{BTreeMap, HashMap}; use std::vec::Vec; use tracing::trace; -use super::error::{OwnerError, RecordError, ZoneErrors}; -use super::inplace::{self, Entry}; - +use super::error::{ContextError, RecordError, ZoneErrors}; use crate::base::iana::{Class, Rtype}; -use crate::base::name::FlattenInto; -use crate::base::ToDname; +use crate::base::name::{FlattenInto, ToName}; use crate::rdata::ZoneRecordData; +use crate::zonefile::inplace::{self, Entry}; use crate::zonetree::ZoneBuilder; -use crate::zonetree::{Rrset, SharedRr, StoredDname, StoredRecord}; +use crate::zonetree::{Rrset, SharedRr}; + +use super::types::{StoredDname, StoredRecord}; //------------ Zonefile ------------------------------------------------------ -/// A parsed sanity checked representation of a zonefile. +/// A parsed sanity checked representation of a zone file. /// /// This type eases creation of a [`ZoneBuilder`] from a collection of /// [`StoredRecord`]s, e.g. and accepts only records that are valid within @@ -31,7 +31,7 @@ use crate::zonetree::{Rrset, SharedRr, StoredDname, StoredRecord}; /// Getter functions provide insight into the classification results. /// /// When ready the [`ZoneBuilder::try_from`] function can be used to convert -/// the parsed zonefile into a pre-populated [`ZoneBuilder`]. +/// the parsed zone file into a pre-populated [`ZoneBuilder`]. /// /// # Usage /// @@ -62,6 +62,8 @@ pub struct Zonefile { } impl Zonefile { + /// Creates an empty in-memory zone file representation for the given apex + /// and class. pub fn new(apex: StoredDname, class: Class) -> Self { Zonefile { origin: Some(apex), @@ -72,11 +74,15 @@ impl Zonefile { } impl Zonefile { + /// Sets the origin of the zone. + /// + /// If parsing a zone file one might call this method on encoutering an + /// `$ORIGIN` directive. pub fn set_origin(&mut self, origin: StoredDname) { self.origin = Some(origin) } - /// Inserts the record into the zone file. + /// Inserts the given record into the zone file. pub fn insert( &mut self, record: StoredRecord, @@ -89,7 +95,7 @@ impl Zonefile { if record.rtype() != Rtype::SOA { return Err(RecordError::MissingSoa(record)); } else { - let apex = record.owner().to_dname(); + let apex = record.owner().to_name(); self.class = Some(record.class()); self.origin = Some(apex); } @@ -154,26 +160,45 @@ impl Zonefile { } impl Zonefile { + /// The [origin] of the zone. + /// + /// [origin]: https://datatracker.ietf.org/doc/html/rfc9499#section-7-2.8 pub fn origin(&self) -> Option<&StoredDname> { self.origin.as_ref() } + /// The [class] of the zone. + /// + /// [class]: https://datatracker.ietf.org/doc/html/rfc9499#section-4-2.2 pub fn class(&self) -> Option { self.class } + /// The collection of normal records in the zone. + /// + /// Normal records are all records in the zone that are neither top of + /// zone administrative records, zone cuts nor glue records. pub fn normal(&self) -> &Owners { &self.normal } + /// The collection of [zone cut] records in the zone. + /// + /// [zone cut]: https://datatracker.ietf.org/doc/html/rfc9499#section-7-2.16 pub fn zone_cuts(&self) -> &Owners { &self.zone_cuts } + /// The collection of [CNAME] records in the zone. + /// + /// [CNAME]: https://datatracker.ietf.org/doc/html/rfc9499#section-7-2.16 pub fn cnames(&self) -> &Owners { &self.cnames } + /// The collection of records that lie outside the zone. + /// + /// In a valid zone this collection will be empty. pub fn out_of_zone(&self) -> &Owners { &self.out_of_zone } @@ -195,7 +220,7 @@ impl TryFrom for ZoneBuilder { let ns = match cut.ns { Some(ns) => ns.into_shared(), None => { - zone_err.add_error(name, OwnerError::MissingNs); + zone_err.add_error(name, ContextError::MissingNs); continue; } }; @@ -210,14 +235,14 @@ impl TryFrom for ZoneBuilder { } if let Err(err) = builder.insert_zone_cut(&name, ns, ds, glue) { - zone_err.add_error(name, OwnerError::InvalidZonecut(err)) + zone_err.add_error(name, ContextError::InvalidZonecut(err)) } } // Now insert all the CNAMEs. for (name, rrset) in zonefile.cnames.into_iter() { if let Err(err) = builder.insert_cname(&name, rrset) { - zone_err.add_error(name, OwnerError::InvalidCname(err)) + zone_err.add_error(name, ContextError::InvalidCname(err)) } } @@ -227,7 +252,7 @@ impl TryFrom for ZoneBuilder { if builder.insert_rrset(&name, rrset.into_shared()).is_err() { zone_err.add_error( name.clone(), - OwnerError::OutOfZone(rtype), + ContextError::OutOfZone(rtype), ); } } @@ -238,11 +263,11 @@ impl TryFrom for ZoneBuilder { for (name, rrsets) in zonefile.out_of_zone.into_iter() { for (rtype, _) in rrsets.into_iter() { zone_err - .add_error(name.clone(), OwnerError::OutOfZone(rtype)); + .add_error(name.clone(), ContextError::OutOfZone(rtype)); } } - zone_err.into_result().map(|_| builder) + zone_err.unwrap().map(|_| builder) } } @@ -269,6 +294,7 @@ impl TryFrom for Zonefile { //------------ Owners -------------------------------------------------------- +/// A set of records of a common type within a zone file. #[derive(Clone)] pub struct Owners { owners: BTreeMap, @@ -348,6 +374,9 @@ impl Default for Owners { //------------ Normal -------------------------------------------------------- +/// A collection of "normal" zone file records. +/// +/// I.e. zone file records that are not CNAMEs or zone cuts. #[derive(Clone, Default)] pub struct Normal { records: HashMap, @@ -374,6 +403,7 @@ impl Normal { //------------ ZoneCut ------------------------------------------------------- +/// The set of records that comprise a zone cut within a zone file. #[derive(Clone, Default)] pub struct ZoneCut { ns: Option, diff --git a/src/zonetree/traits.rs b/src/zonetree/traits.rs index 6d26afa4c..b9f4855b0 100644 --- a/src/zonetree/traits.rs +++ b/src/zonetree/traits.rs @@ -16,12 +16,12 @@ use std::sync::Arc; use crate::base::iana::Class; use crate::base::name::Label; -use crate::base::{Dname, Rtype}; -use crate::zonefile::error::OutOfZone; +use crate::base::{Name, Rtype}; use super::answer::Answer; -use super::types::ZoneCut; -use super::{SharedRr, SharedRrset, StoredDname, WalkOp}; +use super::error::OutOfZone; +use super::types::{StoredDname, ZoneCut}; +use super::{SharedRr, SharedRrset, WalkOp}; //------------ ZoneStore ----------------------------------------------------- @@ -77,7 +77,7 @@ pub trait ReadableZone: Send { /// https://www.rfc-editor.org/rfc/rfc1034#section-3.7.1 fn query( &self, - _qname: Dname, + _qname: Name, _qtype: Rtype, ) -> Result; @@ -92,7 +92,7 @@ pub trait ReadableZone: Send { /// Asynchronous variant of `query()`. fn query_async( &self, - qname: Dname, + qname: Name, qtype: Rtype, ) -> Pin> + Send>> { Box::pin(ready(self.query(qname, qtype))) @@ -126,12 +126,13 @@ pub trait WritableZone { /// Complete a write operation for the zone. /// - /// This function commits the changes accumulated since [`open`] was - /// invoked. Clients who obtain a [`ReadableZone`] interface to this zone - /// _before_ this function has been called will not see any of the changes - /// made since the last commit. Only clients who obtain a [`ReadableZone`] - /// _after_ invoking this function will be able to see the changes made - /// since [`open`] was called. called. + /// This function commits the changes accumulated since + /// [`WritableZone::open`] was invoked. Clients who obtain a + /// [`ReadableZone`] interface to this zone _before_ this function has + /// been called will not see any of the changes made since the last + /// commit. Only clients who obtain a [`ReadableZone`] _after_ invoking + /// this function will be able to see the changes made since + /// [`WritableZone::open`] was called. called. fn commit( &mut self, ) -> Pin>>>; diff --git a/src/zonetree/tree.rs b/src/zonetree/tree.rs index ed9f00758..8c8982089 100644 --- a/src/zonetree/tree.rs +++ b/src/zonetree/tree.rs @@ -1,14 +1,15 @@ //! The known set of zones. -use super::zone::Zone; -use crate::base::iana::Class; -use crate::base::name::{Label, OwnedLabel, ToDname, ToLabelIter}; use std::collections::hash_map; use std::collections::HashMap; -use std::fmt::Display; -use std::io; use std::vec::Vec; +use crate::base::iana::Class; +use crate::base::name::{Label, OwnedLabel, ToLabelIter, ToName}; + +use super::error::ZoneTreeModificationError; +use super::zone::Zone; + //------------ ZoneTree ------------------------------------------------------ /// A multi-rooted [`Zone`] hierarchy. @@ -28,7 +29,7 @@ impl ZoneTree { /// Gets a [`Zone`] for the given apex name and CLASS, if any. pub fn get_zone( &self, - apex_name: &impl ToDname, + apex_name: &impl ToName, class: Class, ) -> Option<&Zone> { self.roots @@ -54,7 +55,7 @@ impl ZoneTree { /// any. pub fn find_zone( &self, - qname: &impl ToDname, + qname: &impl ToName, class: Class, ) -> Option<&Zone> { self.roots.get(class)?.find_zone(qname.iter_labels().rev()) @@ -68,7 +69,7 @@ impl ZoneTree { /// Removes the specified [`Zone`], if any. pub fn remove_zone( &mut self, - apex_name: &impl ToDname, + apex_name: &impl ToName, class: Class, ) -> Result<(), ZoneTreeModificationError> { if let Some(root) = self.roots.get_mut(class) { @@ -256,48 +257,3 @@ impl<'a> Iterator for NodesIter<'a> { Some(node) } } - -//============ Error Types =================================================== - -#[derive(Debug)] -pub enum ZoneTreeModificationError { - ZoneExists, - ZoneDoesNotExist, - Io(io::Error), -} - -impl From for ZoneTreeModificationError { - fn from(src: io::Error) -> Self { - ZoneTreeModificationError::Io(src) - } -} - -impl From for io::Error { - fn from(src: ZoneTreeModificationError) -> Self { - match src { - ZoneTreeModificationError::Io(err) => err, - ZoneTreeModificationError::ZoneDoesNotExist => { - io::Error::new(io::ErrorKind::Other, "zone does not exist") - } - ZoneTreeModificationError::ZoneExists => { - io::Error::new(io::ErrorKind::Other, "zone exists") - } - } - } -} - -impl Display for ZoneTreeModificationError { - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - match self { - ZoneTreeModificationError::ZoneExists => { - write!(f, "Zone already exists") - } - ZoneTreeModificationError::ZoneDoesNotExist => { - write!(f, "Zone does not exist") - } - ZoneTreeModificationError::Io(err) => { - write!(f, "Io error: {err}") - } - } - } -} diff --git a/src/zonetree/types.rs b/src/zonetree/types.rs index 26dd2e072..938dc30e1 100644 --- a/src/zonetree/types.rs +++ b/src/zonetree/types.rs @@ -1,18 +1,21 @@ -use crate::base::name::Dname; -use crate::base::rdata::RecordData; -use crate::base::record::Record; -use crate::base::{iana::Rtype, Ttl}; -use crate::rdata::ZoneRecordData; -use bytes::Bytes; -use serde::{Deserialize, Serialize}; +//! Zone tree related types. + use std::ops; use std::sync::Arc; use std::vec::Vec; +use bytes::Bytes; +use serde::{Deserialize, Serialize}; + +use crate::base::rdata::RecordData; +use crate::base::{iana::Rtype, Ttl}; +use crate::base::{Name, Record}; +use crate::rdata::ZoneRecordData; + //------------ Type Aliases -------------------------------------------------- -/// A [`Bytes`] backed [`Dname`]. -pub type StoredDname = Dname; +/// A [`Bytes`] backed [`Name`]. +pub type StoredDname = Name; /// A [`Bytes`] backed [`ZoneRecordData`]. pub type StoredRecordData = ZoneRecordData; @@ -229,10 +232,18 @@ impl Serialize for SharedRrset { //------------ ZoneCut ------------------------------------------------------- +/// The representation of a zone cut within a zone tree. #[derive(Clone, Debug)] pub struct ZoneCut { + /// The owner name where the zone cut occurs. pub name: StoredDname, + + /// The NS record at the zone cut. pub ns: SharedRrset, + + /// The DS record at the zone cut (optional). pub ds: Option, + + /// Zero or more glue records at the zone cut. pub glue: Vec, } diff --git a/src/zonetree/walk.rs b/src/zonetree/walk.rs index 101c872dd..3c350abe2 100644 --- a/src/zonetree/walk.rs +++ b/src/zonetree/walk.rs @@ -6,13 +6,13 @@ use bytes::Bytes; use super::Rrset; use crate::base::name::OwnedLabel; -use crate::base::{Dname, DnameBuilder}; +use crate::base::{Name, NameBuilder}; /// A callback function invoked for each leaf node visited while walking a /// [`Zone`]. /// /// [`Zone`]: super::Zone -pub type WalkOp = Box, &Rrset) + Send + Sync>; +pub type WalkOp = Box, &Rrset) + Send + Sync>; struct WalkStateInner { op: WalkOp, @@ -49,11 +49,11 @@ impl WalkState { pub(super) fn op(&self, rrset: &Rrset) { if let Some(inner) = &self.inner { let labels = inner.label_stack.lock().unwrap(); - let mut dname = DnameBuilder::new_bytes(); + let mut dname = NameBuilder::new_bytes(); for label in labels.iter().rev() { dname.append_label(label.as_slice()).unwrap(); } - let owner = dname.into_dname().unwrap(); + let owner = dname.into_name().unwrap(); (inner.op)(owner, rrset); } } diff --git a/src/zonetree/zone.rs b/src/zonetree/zone.rs index bdac14649..3524f7ac8 100644 --- a/src/zonetree/zone.rs +++ b/src/zonetree/zone.rs @@ -5,12 +5,13 @@ use std::pin::Pin; use std::sync::Arc; use crate::base::iana::Class; -use crate::zonefile::error::{RecordError, ZoneErrors}; -use crate::zonefile::{inplace, parsed}; +use crate::zonefile::inplace; +use super::error::{RecordError, ZoneErrors}; use super::in_memory::ZoneBuilder; use super::traits::WritableZone; -use super::{ReadableZone, StoredDname, ZoneStore}; +use super::types::StoredDname; +use super::{parsed, ReadableZone, ZoneStore}; //------------ Zone ---------------------------------------------------------- diff --git a/tests/net-client-cache.rs b/tests/net-client-cache.rs index 1cab2bf12..2708db89d 100644 --- a/tests/net-client-cache.rs +++ b/tests/net-client-cache.rs @@ -14,7 +14,7 @@ use rstest::rstest; use tracing::instrument; // use domain::net::client::clock::{Clock, FakeClock}; -use domain::base::{Dname, MessageBuilder, Rtype}; +use domain::base::{MessageBuilder, Name, Rtype}; use domain::net::client::request::Error::NoTransportAvailable; use domain::net::client::request::{RequestMessage, SendRequest}; use domain::net::client::{cache, multi_stream, redundant}; @@ -83,7 +83,7 @@ async fn test_transport_error() { let mut msg = MessageBuilder::new_vec(); msg.header_mut().set_rd(true); let mut msg = msg.question(); - msg.push((Dname::vec_from_str("example.com").unwrap(), Rtype::AAAA)) + msg.push((Name::vec_from_str("example.com").unwrap(), Rtype::AAAA)) .unwrap(); let req = RequestMessage::new(msg); diff --git a/tests/net-server.rs b/tests/net-server.rs index 99b9d9fd6..afcb51521 100644 --- a/tests/net-server.rs +++ b/tests/net-server.rs @@ -16,8 +16,8 @@ use tracing::instrument; use tracing::{trace, warn}; use domain::base::iana::Rcode; +use domain::base::name::{Name, ToName}; use domain::base::wire::Composer; -use domain::base::{Dname, ToDname}; use domain::net::client::{dgram, stream}; use domain::net::server; use domain::net::server::buf::VecBufSource; @@ -272,8 +272,8 @@ fn test_service( ) -> ServiceResult> { fn as_record_and_dname( r: ScannedRecord, - ) -> Option<(ScannedRecord, Dname>)> { - let dname = r.owner().to_dname(); + ) -> Option<(ScannedRecord, Name>)> { + let dname = r.owner().to_name(); Some((r, dname)) } diff --git a/tests/net/stelline/matches.rs b/tests/net/stelline/matches.rs index 07ab71726..5b701ef1b 100644 --- a/tests/net/stelline/matches.rs +++ b/tests/net/stelline/matches.rs @@ -1,13 +1,12 @@ +use crate::net::stelline::parse_query; +use crate::net::stelline::parse_stelline::{Entry, Matches, Reply}; use domain::base::iana::{Opcode, OptRcode, Rtype}; use domain::base::opt::{Opt, OptRecord}; -use domain::base::{Message, ParsedDname, QuestionSection, RecordSection}; +use domain::base::{Message, ParsedName, QuestionSection, RecordSection}; use domain::dep::octseq::Octets; use domain::rdata::ZoneRecordData; use domain::zonefile::inplace::Entry as ZonefileEntry; -use crate::net::stelline::parse_query; -use crate::net::stelline::parse_stelline::{Entry, Matches, Reply}; - pub fn match_msg<'a, Octs: AsRef<[u8]> + Clone + Octets + 'a>( entry: &Entry, msg: &'a Message, @@ -401,7 +400,7 @@ fn match_section< } let msg_rdata = msg_rr .clone() - .into_record::>>() + .into_record::>>() .unwrap() .unwrap(); println!( diff --git a/tests/net/stelline/parse_query.rs b/tests/net/stelline/parse_query.rs index 9c510410d..c82df8aea 100644 --- a/tests/net/stelline/parse_query.rs +++ b/tests/net/stelline/parse_query.rs @@ -20,7 +20,7 @@ use bytes::{Buf, BufMut, Bytes, BytesMut}; use domain::base::charstr::CharStr; use domain::base::iana::{Class, Rtype}; -use domain::base::name::{Chain, Dname, RelativeDname, ToDname}; +use domain::base::name::{Chain, Name, RelativeName, ToName}; use domain::base::scan::{ BadSymbol, ConvertSymbols, EntrySymbol, Scan, Scanner, ScannerError, Symbol, SymbolOctetsError, @@ -31,7 +31,7 @@ use domain::dep::octseq::str::Str; //------------ Type Aliases -------------------------------------------------- /// The type used for scanned domain names. -pub type ScannedDname = Chain, Dname>; +pub type ScannedDname = Chain, Name>; /// The type used for scanned records. @@ -60,7 +60,7 @@ pub struct Zonefile { buf: SourceBuf, /// The current origin. - origin: Option>, + origin: Option>, /// The last owner. last_owner: Option, @@ -89,7 +89,7 @@ impl Zonefile { fn with_buf(buf: SourceBuf) -> Self { Zonefile { buf, - origin: Some(Dname::root_bytes()), + origin: Some(Name::root_bytes()), last_owner: None, last_ttl: Some(Ttl::ZERO), last_class: None, @@ -160,7 +160,7 @@ impl Zonefile { } /// Returns the origin name of the zonefile. - fn get_origin(&self) -> Result, EntryError> { + fn get_origin(&self) -> Result, EntryError> { self.origin .as_ref() .cloned() @@ -198,7 +198,7 @@ pub enum Entry { /// The initial origin name of the included file, if provided. #[allow(dead_code)] - origin: Option>, + origin: Option>, }, } @@ -215,7 +215,7 @@ enum ScannedEntry { Entry(Entry), /// An `$ORIGIN` directive changing the origin name. - Origin(Dname), + Origin(Name), /// A `$TTL` directive changing the default TTL if it isn’t given. Ttl(Ttl), @@ -294,7 +294,7 @@ impl<'a> EntryScanner<'a> { /// Scans a regular record with an owner name of `@`. fn scan_at_record(&mut self) -> Result { - let owner = RelativeDname::empty_bytes() + let owner = RelativeName::empty_bytes() .chain(match self.zonefile.origin.as_ref().cloned() { Some(origin) => origin, None => return Err(EntryError::missing_origin()), @@ -376,13 +376,13 @@ impl<'a> EntryScanner<'a> { fn scan_control(&mut self) -> Result { let ctrl = self.scan_string()?; if ctrl.eq_ignore_ascii_case("$ORIGIN") { - let origin = self.scan_dname()?.to_dname(); + let origin = self.scan_name()?.to_name(); self.zonefile.buf.require_line_feed()?; Ok(ScannedEntry::Origin(origin)) } else if ctrl.eq_ignore_ascii_case("$INCLUDE") { let path = self.scan_string()?; let origin = if !self.zonefile.buf.is_line_feed() { - Some(self.scan_dname()?.to_dname()) + Some(self.scan_name()?.to_name()) } else { None }; @@ -401,7 +401,7 @@ impl<'a> EntryScanner<'a> { impl<'a> Scanner for EntryScanner<'a> { type Octets = Bytes; type OctetsBuilder = BytesMut; - type Dname = ScannedDname; + type Name = ScannedDname; type Error = EntryError; fn has_space(&self) -> bool { @@ -533,7 +533,7 @@ impl<'a> Scanner for EntryScanner<'a> { Ok(res) } - fn scan_dname(&mut self) -> Result { + fn scan_name(&mut self) -> Result { // Because the labels in a domain name have their content preceeded // by the length octet, an unescaped domain name can be almost as is // if we have one extra octet to the left. Luckily, we always do @@ -559,15 +559,15 @@ impl<'a> Scanner for EntryScanner<'a> { // have an empty domain name which is just the origin. self.zonefile.buf.next_item()?; if start == 0 { - return RelativeDname::empty_bytes() + return RelativeName::empty_bytes() .chain(self.zonefile.get_origin()?) .map_err(|_| EntryError::bad_dname()); } else { return unsafe { - RelativeDname::from_octets_unchecked( + RelativeName::from_octets_unchecked( self.zonefile.buf.split_to(write).freeze(), ) - .chain(Dname::root()) + .chain(Name::root()) .map_err(|_| EntryError::bad_dname()) }; } @@ -583,7 +583,7 @@ impl<'a> Scanner for EntryScanner<'a> { // dname. self.zonefile.buf.next_item()?; return unsafe { - RelativeDname::from_octets_unchecked( + RelativeName::from_octets_unchecked( self.zonefile.buf.split_to(write).freeze(), ) .chain(self.zonefile.get_origin()?) @@ -1471,9 +1471,9 @@ mod test { #[derive(serde::Deserialize)] #[allow(clippy::type_complexity)] struct TestCase { - origin: Dname, + origin: Name, zonefile: std::string::String, - result: Vec, ZoneRecordData>>>, + result: Vec, ZoneRecordData>>>, } impl TestCase { From 592304b2fc927cbffea5d242fa679ea9186c4cba Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:17:37 +0200 Subject: [PATCH 15/28] FIX: DisconnectWithFlush ignored in Stream Connection handler. --- src/net/server/connection.rs | 221 ++++++++++++++++++----------------- 1 file changed, 117 insertions(+), 104 deletions(-) diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index bab15a30a..486feb2e2 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -658,131 +658,144 @@ where where Svc::Stream: Send, { - if let Ok(buf) = res { - let received_at = Instant::now(); + match res { + Ok(buf) => { + let received_at = Instant::now(); - if enabled!(Level::TRACE) { - let pcap_text = to_pcap_text(&buf, buf.as_ref().len()); - trace!(addr = %self.addr, pcap_text, "Received message"); - } + if enabled!(Level::TRACE) { + let pcap_text = to_pcap_text(&buf, buf.as_ref().len()); + trace!(addr = %self.addr, pcap_text, "Received message"); + } - self.metrics.inc_num_received_requests(); + self.metrics.inc_num_received_requests(); - // Message received, reset the DNS idle timer - self.idle_timer.full_msg_received(); + // Message received, reset the DNS idle timer + self.idle_timer.full_msg_received(); - match Message::from_octets(buf) { - Err(err) => { - tracing::warn!( - "Failed while parsing request message: {err}" - ); - return Err(ConnectionEvent::ServiceError( - ServiceError::FormatError, - )); - } + match Message::from_octets(buf) { + Err(err) => { + tracing::warn!( + "Failed while parsing request message: {err}" + ); + return Err(ConnectionEvent::ServiceError( + ServiceError::FormatError, + )); + } - Ok(msg) => { - let ctx = NonUdpTransportContext::new(Some( - self.config.load().idle_timeout, - )); - let ctx = TransportSpecificContext::NonUdp(ctx); - let request = - Request::new(self.addr, received_at, msg, ctx); - - let svc = self.service.clone(); - let result_q_tx = self.result_q_tx.clone(); - let metrics = self.metrics.clone(); - let config = self.config.clone(); - - trace!( - "Spawning task to handle new message with id {}", - request.message().header().id() - ); - tokio::spawn(async move { - let request_id = request.message().header().id(); - trace!("Calling service for request id {request_id}"); - let mut stream = svc.call(request).await; - let mut in_transaction = false; - - trace!("Awaiting service call results for request id {request_id}"); - while let Some(Ok(call_result)) = stream.next().await - { - trace!("Processing service call result for request id {request_id}"); - let (response, feedback) = - call_result.into_inner(); - - if let Some(feedback) = feedback { - match feedback { - ServiceFeedback::Reconfigure { - idle_timeout, - } => { - if let Some(idle_timeout) = - idle_timeout - { - debug!( - "Reconfigured connection timeout to {idle_timeout:?}" - ); - let guard = config.load(); - let mut new_config = **guard; - new_config.idle_timeout = - idle_timeout; - config - .store(Arc::new(new_config)); + Ok(msg) => { + let ctx = NonUdpTransportContext::new(Some( + self.config.load().idle_timeout, + )); + let ctx = TransportSpecificContext::NonUdp(ctx); + let request = + Request::new(self.addr, received_at, msg, ctx); + + let svc = self.service.clone(); + let result_q_tx = self.result_q_tx.clone(); + let metrics = self.metrics.clone(); + let config = self.config.clone(); + + trace!( + "Spawning task to handle new message with id {}", + request.message().header().id() + ); + tokio::spawn(async move { + let request_id = request.message().header().id(); + trace!( + "Calling service for request id {request_id}" + ); + let mut stream = svc.call(request).await; + let mut in_transaction = false; + + trace!("Awaiting service call results for request id {request_id}"); + while let Some(Ok(call_result)) = + stream.next().await + { + trace!("Processing service call result for request id {request_id}"); + let (response, feedback) = + call_result.into_inner(); + + if let Some(feedback) = feedback { + match feedback { + ServiceFeedback::Reconfigure { + idle_timeout, + } => { + if let Some(idle_timeout) = + idle_timeout + { + debug!( + "Reconfigured connection timeout to {idle_timeout:?}" + ); + let guard = config.load(); + let mut new_config = **guard; + new_config.idle_timeout = + idle_timeout; + config.store(Arc::new( + new_config, + )); + } } - } - ServiceFeedback::BeginTransaction => { - in_transaction = true; - } + ServiceFeedback::BeginTransaction => { + in_transaction = true; + } - ServiceFeedback::EndTransaction => { - in_transaction = false; + ServiceFeedback::EndTransaction => { + in_transaction = false; + } } } - } - - if let Some(mut response) = response { - loop { - match result_q_tx.try_send(response) { - Ok(()) => { - trace!("Queued message for sending: # pending writes={}", result_q_tx.max_capacity() - - result_q_tx.capacity()); - metrics.set_num_pending_writes( - result_q_tx.max_capacity() - - result_q_tx.capacity(), - ); - break; - } - Err(TrySendError::Closed(_)) => { - error!("Unable to queue message for sending: server is shutting down."); - break; - } + if let Some(mut response) = response { + loop { + match result_q_tx.try_send(response) { + Ok(()) => { + let pending_writes = + result_q_tx + .max_capacity() + - result_q_tx + .capacity(); + trace!("Queued message for sending: # pending writes={pending_writes}"); + metrics + .set_num_pending_writes( + pending_writes, + ); + break; + } - Err(TrySendError::Full( - unused_response, - )) => { - if in_transaction { - // Wait until there is space in the message queue. - tokio::task::yield_now() - .await; - response = unused_response; - } else { - error!("Unable to queue message for sending: queue is full."); + Err(TrySendError::Closed(_)) => { + error!("Unable to queue message for sending: server is shutting down."); break; } + + Err(TrySendError::Full( + unused_response, + )) => { + if in_transaction { + // Wait until there is space in the message queue. + tokio::task::yield_now() + .await; + response = + unused_response; + } else { + error!("Unable to queue message for sending: queue is full."); + break; + } + } } } } } - } - trace!("Finished processing service call results for request id {request_id}"); - }); + trace!("Finished processing service call results for request id {request_id}"); + }); + } } + + Ok(()) } - } - Ok(()) + Err(err) => Err(err), + } } } From f4a1e0b151579a48f8195d3d3aa0235f3e3224c7 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:18:05 +0200 Subject: [PATCH 16/28] Update syntax in broken doc tests. --- src/net/server/dgram.rs | 13 ++-------- src/net/server/service.rs | 54 +++++++++++++++------------------------ src/net/server/stream.rs | 12 ++------- src/net/server/util.rs | 25 +++++------------- 4 files changed, 32 insertions(+), 72 deletions(-) diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index dfd6961a5..5c80b34fd 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -209,20 +209,11 @@ type CommandReceiver = watch::Receiver; /// use domain::net::server::buf::VecBufSource; /// use domain::net::server::dgram::DgramServer; /// use domain::net::server::message::Request; -/// use domain::net::server::middleware::builder::MiddlewareBuilder; -/// use domain::net::server::service::{CallResult, ServiceError, Transaction}; +/// use domain::net::server::service::ServiceResult; /// use domain::net::server::stream::StreamServer; /// use domain::net::server::util::service_fn; /// -/// fn my_service(msg: Request>, _meta: ()) -/// -> Result< -/// Transaction, -/// Pin>, ServiceError> -/// > + Send>>, -/// >, -/// ServiceError, -/// > +/// fn my_service(msg: Request>, _meta: ()) -> ServiceResult> /// { /// todo!() /// } diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 0b4c2ce71..0e2f277cc 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -67,45 +67,47 @@ pub type ServiceResult = Result, ServiceError>; /// use core::future::ready; /// use core::future::Ready; /// +/// use futures::stream::{once, Once, Stream}; +/// /// use domain::base::iana::{Class, Rcode}; /// use domain::base::message_builder::AdditionalBuilder; /// use domain::base::{Name, Message, MessageBuilder, StreamTarget}; /// use domain::net::server::message::Request; -/// use domain::net::server::service::{ -/// CallResult, Service, ServiceError, Transaction -/// }; +/// use domain::net::server::service::{CallResult, Service, ServiceResult}; /// use domain::net::server::util::mk_builder_for_target; /// use domain::rdata::A; /// /// fn mk_answer( /// msg: &Request>, /// builder: MessageBuilder>>, -/// ) -> Result>>, ServiceError> { -/// let mut answer = builder.start_answer(msg.message(), Rcode::NOERROR)?; +/// ) -> AdditionalBuilder>> { +/// let mut answer = builder +/// .start_answer(msg.message(), Rcode::NOERROR) +/// .unwrap(); /// answer.push(( /// Name::root_ref(), /// Class::IN, /// 86400, /// A::from_octets(192, 0, 2, 1), -/// ))?; -/// Ok(answer.additional()) +/// )).unwrap(); +/// answer.additional() /// } /// /// struct MyService; /// /// impl Service> for MyService { /// type Target = Vec; -/// type Future = Ready, ServiceError>>; +/// type Stream = Once>>; +/// type Future = Ready; /// /// fn call( /// &self, /// msg: Request>, -/// ) -> Result, ServiceError> { +/// ) -> Self::Future { /// let builder = mk_builder_for_target(); -/// let additional = mk_answer(&msg, builder)?; -/// let item = ready(Ok(CallResult::new(additional))); -/// let txn = Transaction::single(item); -/// Ok(txn) +/// let additional = mk_answer(&msg, builder); +/// let item = Ok(CallResult::new(additional)); +/// ready(once(ready(item))) /// } /// } /// ``` @@ -123,26 +125,13 @@ pub type ServiceResult = Result, ServiceError>; /// use domain::base::wire::Composer; /// use domain::dep::octseq::{OctetsBuilder, FreezeBuilder, Octets}; /// use domain::net::server::message::Request; -/// use domain::net::server::service::{CallResult, ServiceError, Transaction}; +/// use domain::net::server::service::{CallResult, ServiceError, ServiceResult}; /// use domain::net::server::util::mk_builder_for_target; /// use domain::rdata::A; /// -/// fn name_to_ip( -/// msg: Request>, -/// ) -> Result< -/// Transaction, ServiceError> -/// > + Send, -/// >, -/// ServiceError, -/// > -/// where -/// Target: Composer + Octets + FreezeBuilder + Default + Send, -/// ::AppendError: Debug, -/// { +/// fn name_to_ip(request: Request>) -> ServiceResult> { /// let mut out_answer = None; -/// if let Ok(question) = msg.message().sole_question() { +/// if let Ok(question) = request.message().sole_question() { /// let qname = question.qname(); /// let num_labels = qname.label_count(); /// if num_labels >= 5 { @@ -156,7 +145,7 @@ pub type ServiceResult = Result, ServiceError>; /// let builder = mk_builder_for_target(); /// let mut answer = /// builder -/// .start_answer(msg.message(), Rcode::NOERROR) +/// .start_answer(request.message(), Rcode::NOERROR) /// .unwrap(); /// answer /// .push((Name::root_ref(), Class::IN, 86400, a_rec)) @@ -169,14 +158,13 @@ pub type ServiceResult = Result, ServiceError>; /// if out_answer.is_none() { /// let builder = mk_builder_for_target(); /// let answer = builder -/// .start_answer(msg.message(), Rcode::REFUSED) +/// .start_answer(request.message(), Rcode::REFUSED) /// .unwrap(); /// out_answer = Some(answer); /// } /// /// let additional = out_answer.unwrap().additional(); -/// let item = Ok(CallResult::new(additional)); -/// Ok(Transaction::single(ready(item))) +/// Ok(CallResult::new(additional)) /// } /// ``` /// diff --git a/src/net/server/stream.rs b/src/net/server/stream.rs index 6ff3263c3..0d3547ac9 100644 --- a/src/net/server/stream.rs +++ b/src/net/server/stream.rs @@ -225,19 +225,11 @@ type CommandReceiver = watch::Receiver; /// use domain::base::Message; /// use domain::net::server::buf::VecBufSource; /// use domain::net::server::message::Request; -/// use domain::net::server::service::{CallResult, ServiceError, Transaction}; +/// use domain::net::server::service::{CallResult, ServiceError, ServiceResult}; /// use domain::net::server::stream::StreamServer; /// use domain::net::server::util::service_fn; /// -/// fn my_service(msg: Request>, _meta: ()) -/// -> Result< -/// Transaction, -/// Pin>, ServiceError> -/// > + Send>>, -/// >, -/// ServiceError, -/// > +/// fn my_service(msg: Request>, _meta: ()) -> ServiceResult> /// { /// todo!() /// } diff --git a/src/net/server/util.rs b/src/net/server/util.rs index 17fa2b022..f2048a300 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -68,7 +68,7 @@ where /// use domain::base::iana::Rcode; /// use domain::base::Message; /// use domain::net::server::message::Request; -/// use domain::net::server::service::{CallResult, ServiceError, Transaction}; +/// use domain::net::server::service::{CallResult, ServiceError, ServiceResult}; /// use domain::net::server::util::{mk_builder_for_target, service_fn}; /// /// // Define some types to make the example easier to read. @@ -77,23 +77,12 @@ where /// // Implement the application logic of our service. /// // Takes the received DNS request and any additional meta data you wish to /// // provide, and returns one or more future DNS responses. -/// fn my_service( -/// req: Request>, -/// _meta: MyMeta, -/// ) -> Result< -/// Transaction, -/// Pin>, ServiceError> -/// >>>, -/// >, -/// ServiceError, -/// > { -/// // For each request create a single response: -/// Ok(Transaction::single(Box::pin(async move { -/// let builder = mk_builder_for_target(); -/// let answer = builder.start_answer(req.message(), Rcode::NXDOMAIN)?; -/// Ok(CallResult::new(answer.additional())) -/// }))) +/// fn my_service(req: Request>, _meta: MyMeta) +/// -> ServiceResult> +/// { +/// let builder = mk_builder_for_target(); +/// let answer = builder.start_answer(req.message(), Rcode::NXDOMAIN)?; +/// Ok(CallResult::new(answer.additional())) /// } /// /// // Turn my_service() into an actual Service trait impl. From d236fffce9c35f748da69daf19d6aec27f37e5a1 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:18:11 +0200 Subject: [PATCH 17/28] Add a bare bones demo of an RFC 9567 monitoring agent implemented using the domain server functionality. --- examples/serve-rfc9567-agent.rs | 149 ++++++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) create mode 100644 examples/serve-rfc9567-agent.rs diff --git a/examples/serve-rfc9567-agent.rs b/examples/serve-rfc9567-agent.rs new file mode 100644 index 000000000..86fb406ef --- /dev/null +++ b/examples/serve-rfc9567-agent.rs @@ -0,0 +1,149 @@ +#![cfg(feature = "siphasher")] +use core::future::pending; + +use std::str::FromStr; +use std::sync::Arc; + +use domain::rdata::rfc1035::TxtBuilder; +use tokio::net::{TcpSocket, UdpSocket}; +use tracing_subscriber::EnvFilter; + +use domain::base::iana::{Class, Rcode}; +use domain::base::name::{Label, ToLabelIter}; +use domain::base::{CharStr, NameBuilder, Ttl}; +use domain::net::server::buf::VecBufSource; +use domain::net::server::dgram::DgramServer; +use domain::net::server::message::Request; +use domain::net::server::middleware::cookies::CookiesMiddlewareSvc; +use domain::net::server::middleware::edns::EdnsMiddlewareSvc; +use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; +use domain::net::server::service::{CallResult, ServiceResult}; +use domain::net::server::stream::StreamServer; +use domain::net::server::util::{mk_builder_for_target, service_fn}; + +//----------- my_service() --------------------------------------------------- + +fn my_service( + request: Request>, + _metadata: (), +) -> ServiceResult> { + let mut out_answer = None; + if let Ok(question) = request.message().sole_question() { + // We're expecting an RFC 9567 compatible query, i.e. a QNAME of the + // form: + // _er...._er. + // This has at least 6 labels. + // See: https://www.rfc-editor.org/rfc/rfc9567#name-constructing-the-report-que + let qname = question.qname(); + let num_labels = qname.label_count(); + if num_labels >= 6 { + let mut iter = qname.iter_labels(); + let _er = iter.next().unwrap(); + let rep_qtype = iter.next().unwrap(); + let mut rep_qname = NameBuilder::new_vec(); + let mut second_last_label = Option::<&Label>::None; + let mut last_label = None; + loop { + let label = iter.next().unwrap(); + if let Some(label) = second_last_label { + rep_qname.append_label(label.as_slice()).unwrap(); + } + if label == "_er" { + break; + } else { + second_last_label = last_label; + last_label = Some(label); + } + } + let rep_qname = rep_qname.finish(); + let edns_err_code = last_label.unwrap(); + + // Invoke local program to handle the error report + // TODO + eprintln!("Received error report:"); + eprintln!("QNAME: {rep_qname}"); + eprintln!("QTYPE: {rep_qtype}"); + eprintln!("EDNS error: {edns_err_code}"); + + // https://www.rfc-editor.org/rfc/rfc9567#section-6.3-1 + // "It is RECOMMENDED that the authoritative server for the agent + // domain reply with a positive response (i.e., not with NODATA or + // NXDOMAIN) containing a TXT record." + let builder = mk_builder_for_target(); + let mut answer = builder + .start_answer(request.message(), Rcode::NOERROR) + .unwrap(); + let mut txt_builder = TxtBuilder::>::new(); + let txt = { + let cs = + CharStr::>::from_str("Report received").unwrap(); + txt_builder.append_charstr(&cs).unwrap(); + txt_builder.finish().unwrap() + }; + answer + .push((qname, Class::IN, Ttl::from_days(1), txt)) + .unwrap(); + out_answer = Some(answer); + } + } + + if out_answer.is_none() { + let builder = mk_builder_for_target(); + out_answer = Some( + builder + .start_answer(request.message(), Rcode::REFUSED) + .unwrap(), + ); + } + + let additional = out_answer.unwrap().additional(); + Ok(CallResult::new(additional)) +} + +//----------- main() --------------------------------------------------------- + +#[tokio::main(flavor = "multi_thread")] +async fn main() { + // ----------------------------------------------------------------------- + // Setup logging. You can override the log level by setting environment + // variable RUST_LOG, e.g. RUST_LOG=trace. + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .with_thread_ids(true) + .without_time() + .try_init() + .ok(); + + // ----------------------------------------------------------------------- + // Create a service with accompanying middleware chain to answer incoming + // requests. + let svc = service_fn(my_service, ()); + // https://www.rfc-editor.org/rfc/rfc9567#section-6.3-2 "The monitoring + // agent SHOULD respond to queries received over UDP that have no DNS + // Cookie set with a response that has the truncation bit (TC bit) set to + // challenge the resolver to requery over TCP." + let svc = CookiesMiddlewareSvc::, _>::with_random_secret(svc); + let svc = EdnsMiddlewareSvc::, _>::new(svc); + let svc = MandatoryMiddlewareSvc::, _>::new(svc); + let svc = Arc::new(svc); + + // ----------------------------------------------------------------------- + // Run a DNS server on UDP port 8053 on 127.0.0.1 using the my_service + // service defined above and accompanying middleware. + let udpsocket = UdpSocket::bind("127.0.0.1:8053").await.unwrap(); + let buf = Arc::new(VecBufSource); + let srv = DgramServer::new(udpsocket, buf.clone(), svc.clone()); + tokio::spawn(async move { srv.run().await }); + + // ----------------------------------------------------------------------- + // Run a DNS server on TCP port 8053 on 127.0.0.1 using the same service. + let v4socket = TcpSocket::new_v4().unwrap(); + v4socket.set_reuseaddr(true).unwrap(); + v4socket.bind("127.0.0.1:8053".parse().unwrap()).unwrap(); + let v4listener = v4socket.listen(1024).unwrap(); + let buf = Arc::new(VecBufSource); + let srv = StreamServer::new(v4listener, buf.clone(), svc); + tokio::spawn(async move { srv.run().await }); + + pending().await +} From 31e5d7cb5e0054dafbabf3e3ab692f62637f687a Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:40:54 +0200 Subject: [PATCH 18/28] Fix broken RustDoc references. --- src/net/server/dgram.rs | 70 +++----------------------- src/net/server/message.rs | 22 ++++---- src/net/server/middleware/cookies.rs | 3 +- src/net/server/middleware/edns.rs | 3 +- src/net/server/middleware/mandatory.rs | 6 +-- src/net/server/mod.rs | 34 ++++++------- src/net/server/service.rs | 22 ++++---- 7 files changed, 45 insertions(+), 115 deletions(-) diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index 5c80b34fd..03a040328 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -71,7 +71,7 @@ const WRITE_TIMEOUT: DefMinMax = DefMinMax::new( /// The value has to be between 512 and 4,096 per [RFC 6891]. The default /// value is 1232 per the [2020 DNS Flag Day]. /// -/// The [`Service`] and [`MiddlewareChain`] (if any) are responsible for +/// The [`Service`] and middleware chain (if any) are responsible for /// enforcing this limit. /// /// [2020 DNS Flag Day]: http://www.dnsflagday.net/2020/ @@ -104,7 +104,7 @@ impl Config { /// Pass `None` to prevent sending a limit suggestion to the middleware /// (if any) and service. /// - /// The [`Service`] and [`MiddlewareChain`] (if any) are response for + /// The [`Service`] and middleware chain (if any) are responsible for /// enforcing the suggested limit, or deciding what to do if this is None. /// /// # Reconfigure @@ -493,11 +493,11 @@ where trace!(%addr, pcap_text, "Received message"); } - let state = self.mk_state_for_request(); - let svc = self.service.clone(); let cfg = self.config.clone(); let metrics = self.metrics.clone(); + let cloned_sock = self.sock.clone(); + let write_timeout = self.config.load().write_timeout; tokio::spawn(async move { match Message::from_octets(buf) { @@ -544,10 +544,10 @@ where // Actually write the DNS response message bytes to the UDP // socket. let _ = Self::send_to( - &state.sock, + &cloned_sock, bytes, &addr, - state.write_timeout, + write_timeout, ) .await; @@ -648,18 +648,6 @@ where Ok(()) } } - - /// Helper function to package references to key parts of our server state - /// into a [`RequestState`] ready for passing through the - /// [`CommonMessageFlow`] call chain and ultimately back to ourselves at - /// [`process_call_reusult`]. - fn mk_state_for_request(&self) -> RequestState { - RequestState::new( - self.sock.clone(), - self.command_tx.clone(), - self.config.load().write_timeout, - ) - } } //--- Drop @@ -681,49 +669,3 @@ where let _ = self.shutdown(); } } - -//----------- RequestState --------------------------------------------------- - -/// Data needed by [`DgramServer::process_call_result`] which needs to be -/// passed through the [`CommonMessageFlow`] call chain. -pub struct RequestState { - /// The network socket over which this request was received and over which - /// the response should be sent. - sock: Arc, - - /// A sender for sending [`ServerCommand`]s. - /// - /// Used to signal the server to stop, reconfigure, etc. - command_tx: CommandSender, - - /// The maximum amount of time to wait for a response datagram to be - /// accepted by the operating system for writing back to the client. - write_timeout: Duration, -} - -impl RequestState { - /// Creates a new instance of [`RequestState`]. - fn new( - sock: Arc, - command_tx: CommandSender, - write_timeout: Duration, - ) -> Self { - Self { - sock, - command_tx, - write_timeout, - } - } -} - -//--- Clone - -impl Clone for RequestState { - fn clone(&self) -> Self { - Self { - sock: self.sock.clone(), - command_tx: self.command_tx.clone(), - write_timeout: self.write_timeout, - } - } -} diff --git a/src/net/server/message.rs b/src/net/server/message.rs index dd1d673ed..0eb8369cd 100644 --- a/src/net/server/message.rs +++ b/src/net/server/message.rs @@ -35,14 +35,14 @@ impl UdpTransportContext { /// allowed response size, `Some(n)` otherwise where `n` is the maximum /// number of bytes allowed for the response message. /// - /// The [`EdnsMiddlewareProcessor`] may adjust this limit. + /// The [`EdnsMiddlewareSvc`] may adjust this limit. /// - /// The [`MandatoryMiddlewareProcessor`] enforces this limit. + /// The [`MandatoryMiddlewareSvc`] enforces this limit. /// - /// [`EdnsMiddlewareProcessor`]: - /// crate::net::server::middleware::processors::edns::EdnsMiddlewareProcessor - /// [`MandatoryMiddlewareProcessor`]: - /// crate::net::server::middleware::processors::mandatory::MandatoryMiddlewareProcessor + /// [`EdnsMiddlewareSvc`]: + /// crate::net::server::middleware::edns::EdnsMiddlewareSvc + /// [`MandatoryMiddlewareSvc`]: + /// crate::net::server::middleware::mandatory::MandatoryMiddlewareSvc pub fn max_response_size_hint(&self) -> Option { *self.max_response_size_hint.lock().unwrap() } @@ -81,15 +81,15 @@ impl NonUdpTransportContext { /// This is provided by the server to indicate what the current timeout /// setting in effect is. /// - /// The [`EdnsMiddlewareProcessor`] may report this timeout value back to + /// The [`EdnsMiddlewareSvc`] may report this timeout value back to /// clients capable of interpreting it. /// /// [RFC 7766 section 6.2.3]: /// https://datatracker.ietf.org/doc/html/rfc7766#section-6.2.3 /// [RFC 78828]: https://www.rfc-editor.org/rfc/rfc7828 /// - /// [`EdnsMiddlewareProcessor`]: - /// crate::net::server::middleware::processors::edns::EdnsMiddlewareProcessor + /// [`EdnsMiddlewareSvc`]: + /// crate::net::server::middleware::edns::EdnsMiddlewareSvc pub fn idle_timeout(&self) -> Option { self.idle_timeout } @@ -103,9 +103,11 @@ impl NonUdpTransportContext { /// correctly. Some kinds of contextual information are only available for /// certain transport types. /// -/// Context values may be adjusted by processors in the [`MiddlewareChain`] +/// Context values may be adjusted by processors in the middleware chain /// and/or by the [`Service`] that receives the request, in order to influence /// the behaviour of other processors, the service or the server. +/// +/// [`Service`]: crate::net::server::service::Service #[derive(Debug, Clone)] pub enum TransportSpecificContext { /// Context for a UDP transport. diff --git a/src/net/server/middleware/cookies.rs b/src/net/server/middleware/cookies.rs index 3cc320a5f..c8fc81107 100644 --- a/src/net/server/middleware/cookies.rs +++ b/src/net/server/middleware/cookies.rs @@ -32,7 +32,7 @@ const FIVE_MINUTES_AS_SECS: u32 = 5 * 60; /// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. const ONE_HOUR_AS_SECS: u32 = 60 * 60; -/// A DNS Cookies [`MiddlewareProcessor`]. +/// A DNS Cookies middleware service /// /// Standards covered by ths implementation: /// @@ -43,7 +43,6 @@ const ONE_HOUR_AS_SECS: u32 = 60 * 60; /// /// [7873]: https://datatracker.ietf.org/doc/html/rfc7873 /// [9018]: https://datatracker.ietf.org/doc/html/rfc7873 -/// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor #[derive(Clone, Debug)] pub struct CookiesMiddlewareSvc { svc: Svc, diff --git a/src/net/server/middleware/edns.rs b/src/net/server/middleware/edns.rs index 38473ce41..45fbcc063 100644 --- a/src/net/server/middleware/edns.rs +++ b/src/net/server/middleware/edns.rs @@ -31,7 +31,7 @@ use super::stream::PostprocessingStream; /// [IANA registry]: https://www.iana.org/assignments/dns-parameters/dns-parameters.xhtml#dns-parameters-14 const EDNS_VERSION_ZERO: u8 = 0; -/// A [`MiddlewareProcessor`] for adding EDNS(0) related functionality. +/// A middleware service for adding EDNS(0) related functionality. /// /// Standards covered by ths implementation: /// @@ -44,7 +44,6 @@ const EDNS_VERSION_ZERO: u8 = 0; /// [6891]: https://datatracker.ietf.org/doc/html/rfc6891 /// [7828]: https://datatracker.ietf.org/doc/html/rfc7828 /// [9210]: https://datatracker.ietf.org/doc/html/rfc9210 -/// [`MiddlewareProcessor`]: crate::net::server::middleware::processor::MiddlewareProcessor #[derive(Clone, Debug, Default)] pub struct EdnsMiddlewareSvc { svc: Svc, diff --git a/src/net/server/middleware/mandatory.rs b/src/net/server/middleware/mandatory.rs index a205f1f48..bacc0747c 100644 --- a/src/net/server/middleware/mandatory.rs +++ b/src/net/server/middleware/mandatory.rs @@ -26,8 +26,8 @@ use super::stream::{MiddlewareStream, PostprocessingStream}; /// [RFC 1035 section 4.2.1]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1 pub const MINIMUM_RESPONSE_BYTE_LEN: u16 = 512; -/// A [`MiddlewareProcessor`] for enforcing core RFC MUST requirements on -/// processed messages. +/// A middleware service for enforcing core RFC MUST requirements on processed +/// messages. /// /// Standards covered by ths implementation: /// @@ -36,8 +36,6 @@ pub const MINIMUM_RESPONSE_BYTE_LEN: u16 = 512; /// | [1035] | TBD | /// | [2181] | TBD | /// -/// [`MiddlewareProcessor`]: -/// crate::net::server::middleware::processor::MiddlewareProcessor /// [1035]: https://datatracker.ietf.org/doc/html/rfc1035 /// [2181]: https://datatracker.ietf.org/doc/html/rfc2181 #[derive(Clone, Debug)] diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index cae1d2d8d..27fb886df 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -8,7 +8,7 @@ //! //! This module provides skeleton asynchronous server implementations based on //! the [Tokio](https://tokio.rs/) async runtime. In combination with an -//! appropriate network source, optional [`MiddlewareChain`] and your own +//! appropriate network source, optional middleware chain and your own //! [`Service`] implementation they can be used to run a standards compliant //! DNS server that answers requests based on the application logic you //! specify. @@ -32,7 +32,7 @@ //! # Getting started //! //! Servers are implemented by combining a server transport (see [dgram] and -//! [stream]), [`BufSource`], (optional) [`MiddlewareChain`] and [`Service`] +//! [stream]), [`BufSource`], (optional) MiddlewareChain and [`Service`] //! together. //! //! Whether using [`DgramServer`] or [`StreamServer`] the required steps are @@ -42,7 +42,7 @@ //! - Construct a server transport with `new()` passing in the network //! source as an argument. //! - Tune the server behaviour via builder functions such as -//! `with_middleware()`. +//! `with_config()`. //! - `run()` the server. //! - `shutdown()` the server, explicitly or on [`drop`]. //! @@ -92,20 +92,17 @@ //! specific layer of a server nor does it constitute part of the core //! application logic. //! -//! With Middleware mandatory functionality and logic required by all +//! With Middleware, mandatory functionality and logic required by all //! standards compliant DNS servers can be incorporated into your server by -//! building a [`MiddlewareChain`] starting from -//! [`MiddlewareBuilder::default`]. +//! layering your service on top of the [`MandatoryMiddlewareSvc`]. //! //! You can also opt to incorporate additional behaviours into your DNS server -//! from a selection of pre-supplied implementations via -//! [`MiddlewareBuilder::push`]. See the various implementations of -//! [`MiddlewareProcessor`] for more information. +//! from a selection of [pre-supplied middleware]. //! -//! And if the existing middleware processors don't meet your needs, maybe you +//! And if the existing middleware services don't meet your needs, maybe you //! have specific access control or rate limiting requirements for example, -//! you can implement [`MiddlewareProcessor`] yourself to add your own pre- -//! and post- processing stages into your DNS server. +//! you can implement your own middleware service to add your own pre- and +//! post- processing stages into your DNS server. //! //! ## Application logic //! @@ -131,7 +128,9 @@ //! //! ## Performance //! -//! Both [`DgramServer`] and [`StreamServer`] use [`CommonMessageFlow`] to +//!
TODO: This section is outdated!
+//! +//! Both [`DgramServer`] and [`StreamServer`] use `CommonMessageFlow` to //! pre-process the request, invoke [`Service::call`], and post-process the //! response. //! @@ -182,14 +181,9 @@ //! [`AsyncDgramSock`]: sock::AsyncDgramSock //! [`BufSource`]: buf::BufSource //! [`DgramServer`]: dgram::DgramServer -//! [`CommonMessageFlow`]: message::CommonMessageFlow //! [Middleware]: middleware -//! [`MiddlewareBuilder::default`]: -//! middleware::builder::MiddlewareBuilder::default() -//! [`MiddlewareBuilder::push`]: -//! middleware::builder::MiddlewareBuilder::push() -//! [`MiddlewareChain`]: middleware::chain::MiddlewareChain -//! [`MiddlewareProcessor`]: middleware::processor::MiddlewareProcessor +//! [`MandatoryMiddlewareSvc`]: middleware::mandatory::MandatoryMiddlewareSvc +//! [pre-supplied middleware]: middleware //! [`Service`]: service::Service //! [`Service::call`]: service::Service::call() //! [`StreamServer`]: stream::StreamServer diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 0e2f277cc..de5d5f64a 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -1,9 +1,8 @@ //! The application logic of a DNS server. //! //! The [`Service::call`] function defines how the service should respond to a -//! given DNS request. resulting in a [`Transaction`] containing a transaction -//! that yields one or more future DNS responses, and/or a -//! [`ServiceFeedback`]. +//! given DNS request. resulting in a future that yields a stream of one or +//! more future DNS responses, and/or [`ServiceFeedback`]. use core::fmt::Display; use core::ops::Deref; @@ -29,20 +28,19 @@ pub type ServiceResult = Result, ServiceError>; /// requests. /// /// A request is "valid" if it passed successfully through the underlying -/// server (e.g. [`DgramServer`] or [`StreamServer`]) and [`MiddlewareChain`] +/// server (e.g. [`DgramServer`] or [`StreamServer`]) and middleware /// stages. /// /// For an overview of how services fit into the total flow of request and /// response handling see the [net::server module documentation]. /// /// Each [`Service`] implementation defines a [`call`] function which takes a -/// [`Request`] DNS request as input and returns either a [`Transaction`] on -/// success, or a [`ServiceError`] on failure, as output. +/// [`Request`] DNS request as input and returns a future that yields a stream +/// of one or more items each of which is either a [`CallResult`] or +/// [`ServiceError`]. /// -/// Each [`Transaction`] contains either a single DNS response message, or a -/// stream of DNS response messages (e.g. for a zone transfer). Each response -/// message is returned as a [`Future`] which the underlying server will -/// resolve to a [`CallResult`]. +/// Most DNS requests result in a single response, with the exception of AXFR +/// and IXFR requests which can result in a stream of responses. /// /// # Usage /// @@ -57,7 +55,7 @@ pub type ServiceResult = Result, ServiceError>; /// Whichever approach you choose it is important to minimize the work done /// before returning from [`Service::call`], as time spent here blocks the /// caller. Instead as much work as possible should be delegated to the -/// futures returned as a [`Transaction`]. +/// future returned. /// /// /// @@ -179,8 +177,6 @@ pub type ServiceResult = Result, ServiceError>; /// See [`service_fn`] for an example of how to use it to create a [`Service`] /// impl from a funciton. /// -/// [`MiddlewareChain`]: -/// crate::net::server::middleware::chain::MiddlewareChain /// [`DgramServer`]: crate::net::server::dgram::DgramServer /// [`StreamServer`]: crate::net::server::stream::StreamServer /// [net::server module documentation]: crate::net::server From a45a09f1b04caa15b82789b5960325a08ae040ad Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:44:29 +0200 Subject: [PATCH 19/28] Fix compilation error "`main` function not found in crate `serve_rfc9567_agent`". --- Cargo.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index 7073e2877..7e31fbd12 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -130,6 +130,10 @@ required-features = ["zonefile", "unstable-zonetree"] name = "serve-zone" required-features = ["zonefile", "net", "unstable-server-transport", "unstable-zonetree"] +[[example]] +name = "serve-rfc9567-agent" +required-features = ["net", "unstable-server-transport", "siphasher"] + # This example is commented out because it is difficult, if not impossible, # when including the sqlx dependency, to make the dependency tree compatible # with both `cargo +nightly update -Z minimal versions` and the crate minimum From 429de28b9d8289384f749d6862d7e5e042e192b0 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 30 Apr 2024 21:48:08 +0200 Subject: [PATCH 20/28] Fix compilation error "use of unstable library feature 'option_as_slice'". --- src/net/server/middleware/mandatory.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/net/server/middleware/mandatory.rs b/src/net/server/middleware/mandatory.rs index bacc0747c..be41480ff 100644 --- a/src/net/server/middleware/mandatory.rs +++ b/src/net/server/middleware/mandatory.rs @@ -444,7 +444,7 @@ mod tests { let (response, _feedback) = call_result.into_inner(); // Get the response length - let new_size = response.as_slice().len(); + let new_size = response.unwrap().as_slice().len(); if new_size < old_size { Some(new_size) From ce0f02992b7adb3a779bad112fc0ab0705a314ae Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Wed, 1 May 2024 00:09:27 +0200 Subject: [PATCH 21/28] FIX: Cookie middleware doesn't do post-processing, so actually support the identity case. --- examples/server-transports.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server-transports.rs b/examples/server-transports.rs index fc8e5de68..869893ef0 100644 --- a/examples/server-transports.rs +++ b/examples/server-transports.rs @@ -537,6 +537,7 @@ where { type Target = Svc::Target; type Stream = MiddlewareStream< + Svc::Future, Svc::Stream, PostprocessingStream< RequestOctets, From 4313cfd086b5ec7c8be38f7d01e0995a0c639788 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Wed, 1 May 2024 00:09:44 +0200 Subject: [PATCH 22/28] Files missed from previous commit. --- src/net/server/middleware/cookies.rs | 110 +++++++++---------------- src/net/server/middleware/edns.rs | 1 + src/net/server/middleware/mandatory.rs | 1 + src/net/server/middleware/stream.rs | 46 ++++++++--- 4 files changed, 77 insertions(+), 81 deletions(-) diff --git a/src/net/server/middleware/cookies.rs b/src/net/server/middleware/cookies.rs index 11149495e..21894f9d3 100644 --- a/src/net/server/middleware/cookies.rs +++ b/src/net/server/middleware/cookies.rs @@ -18,12 +18,10 @@ use crate::base::wire::{Composer, ParseError}; use crate::base::{Serial, StreamTarget}; use crate::net::server::message::Request; use crate::net::server::middleware::stream::MiddlewareStream; -use crate::net::server::service::{CallResult, Service, ServiceResult}; +use crate::net::server::service::{CallResult, Service}; use crate::net::server::util::add_edns_options; use crate::net::server::util::{mk_builder_for_target, start_reply}; -use super::stream::PostprocessingStream; - /// The five minute period referred to by /// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. const FIVE_MINUTES_AS_SECS: u32 = 5 * 60; @@ -396,51 +394,38 @@ where ControlFlow::Continue(()) } - fn postprocess( - _request: &Request, - _response: &mut AdditionalBuilder>, - _server_secret: [u8; 16], - ) where - RequestOctets: Octets, - { - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 - // No OPT RR or No COOKIE Option: - // If the request lacked a client cookie we don't need to do - // anything. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.2 - // Malformed COOKIE Option: - // If the request COOKIE option was malformed we would have already - // rejected it during pre-processing so again nothing to do here. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 - // Only a Client Cookie: - // If the request had a client cookie but no server cookie and - // we didn't already reject the request during pre-processing. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.4 - // A Client Cookie and an Invalid Server Cookie: - // Per RFC 7873 this is handled the same way as the "Only a Client - // Cookie" case. - // - // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 - // A Client Cookie and a Valid Server Cookie - // Any server cookie will already have been validated during - // pre-processing, we don't need to check it again here. - } - - fn map_stream_item( - request: Request, - mut stream_item: ServiceResult, - server_secret: [u8; 16], - ) -> ServiceResult { - if let Ok(cr) = &mut stream_item { - if let Some(response) = cr.response_mut() { - Self::postprocess(&request, response, server_secret); - } - } - stream_item - } + // fn postprocess( + // _request: &Request, + // _response: &mut AdditionalBuilder>, + // _server_secret: [u8; 16], + // ) where + // RequestOctets: Octets, + // { + // // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.1 + // // No OPT RR or No COOKIE Option: + // // If the request lacked a client cookie we don't need to do + // // anything. + // // + // // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.2 + // // Malformed COOKIE Option: + // // If the request COOKIE option was malformed we would have already + // // rejected it during pre-processing so again nothing to do here. + // // + // // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 + // // Only a Client Cookie: + // // If the request had a client cookie but no server cookie and + // // we didn't already reject the request during pre-processing. + // // + // // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.4 + // // A Client Cookie and an Invalid Server Cookie: + // // Per RFC 7873 this is handled the same way as the "Only a Client + // // Cookie" case. + // // + // // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.5 + // // A Client Cookie and a Valid Server Cookie + // // Any server cookie will already have been validated during + // // pre-processing, we don't need to check it again here. + // } } //--- Service @@ -456,13 +441,9 @@ where { type Target = Svc::Target; type Stream = MiddlewareStream< + Svc::Future, + Svc::Stream, Svc::Stream, - PostprocessingStream< - RequestOctets, - Svc::Future, - Svc::Stream, - [u8; 16], - >, Once::Item>>, ::Item, >; @@ -472,24 +453,11 @@ where match self.preprocess(&request) { ControlFlow::Continue(()) => { let svc_call_fut = self.svc.call(request.clone()); - let map = PostprocessingStream::new( - svc_call_fut, - request, - self.server_secret, - Self::map_stream_item, - ); - ready(MiddlewareStream::Map(map)) - } - ControlFlow::Break(mut response) => { - Self::postprocess( - &request, - &mut response, - self.server_secret, - ); - ready(MiddlewareStream::Result(once(ready(Ok( - CallResult::new(response), - ))))) + ready(MiddlewareStream::IdentityFuture(svc_call_fut)) } + ControlFlow::Break(response) => ready(MiddlewareStream::Result( + once(ready(Ok(CallResult::new(response)))), + )), } } } diff --git a/src/net/server/middleware/edns.rs b/src/net/server/middleware/edns.rs index 45fbcc063..5dcfa6ec7 100644 --- a/src/net/server/middleware/edns.rs +++ b/src/net/server/middleware/edns.rs @@ -330,6 +330,7 @@ where { type Target = Svc::Target; type Stream = MiddlewareStream< + Svc::Future, Svc::Stream, PostprocessingStream, Once::Item>>, diff --git a/src/net/server/middleware/mandatory.rs b/src/net/server/middleware/mandatory.rs index be41480ff..b2b74bae1 100644 --- a/src/net/server/middleware/mandatory.rs +++ b/src/net/server/middleware/mandatory.rs @@ -285,6 +285,7 @@ where { type Target = Svc::Target; type Stream = MiddlewareStream< + Svc::Future, Svc::Stream, PostprocessingStream, Once::Item>>, diff --git a/src/net/server/middleware/stream.rs b/src/net/server/middleware/stream.rs index b20847333..04dd62a6b 100644 --- a/src/net/server/middleware/stream.rs +++ b/src/net/server/middleware/stream.rs @@ -1,5 +1,5 @@ use core::ops::DerefMut; -use core::task::{Context, Poll}; +use core::task::{ready, Context, Poll}; use std::pin::Pin; @@ -8,19 +8,31 @@ use futures::stream::{Stream, StreamExt}; use octseq::Octets; use crate::net::server::message::Request; +use core::future::Future; use tracing::trace; //------------ MiddlewareStream ---------------------------------------------- -pub enum MiddlewareStream -where +pub enum MiddlewareStream< + IdentityFuture, + IdentityStream, + MapStream, + ResultStream, + StreamItem, +> where + IdentityFuture: Future, IdentityStream: Stream, MapStream: Stream, ResultStream: Stream, { - /// The inner service response will be passed through this service without - /// modification. - Identity(IdentityStream), + /// The inner service response future will be passed through this service + /// without modification, resolving the future first and then the + /// resulting IdentityStream next. + IdentityFuture(IdentityFuture), + + /// The inner service response stream will be passed through this service + /// without modification. + IdentityStream(IdentityStream), /// Either a single response has been created without invoking the innter /// service, or the inner service response will be post-processed by this @@ -33,9 +45,17 @@ where //--- impl Stream -impl Stream - for MiddlewareStream +impl + Stream + for MiddlewareStream< + IdentityFuture, + IdentityStream, + MapStream, + ResultStream, + StreamItem, + > where + IdentityFuture: Future + Unpin, IdentityStream: Stream + Unpin, MapStream: Stream + Unpin, ResultStream: Stream + Unpin, @@ -48,7 +68,12 @@ where cx: &mut core::task::Context<'_>, ) -> Poll> { match self.deref_mut() { - MiddlewareStream::Identity(s) => s.poll_next_unpin(cx), + MiddlewareStream::IdentityFuture(f) => { + let stream = ready!(f.poll_unpin(cx)); + *self = MiddlewareStream::IdentityStream(stream); + self.poll_next(cx) + } + MiddlewareStream::IdentityStream(s) => s.poll_next_unpin(cx), MiddlewareStream::Map(s) => s.poll_next_unpin(cx), MiddlewareStream::Result(s) => s.poll_next_unpin(cx), } @@ -56,7 +81,8 @@ where fn size_hint(&self) -> (usize, Option) { match self { - MiddlewareStream::Identity(s) => s.size_hint(), + MiddlewareStream::IdentityFuture(_) => (0, None), + MiddlewareStream::IdentityStream(s) => s.size_hint(), MiddlewareStream::Map(s) => s.size_hint(), MiddlewareStream::Result(s) => s.size_hint(), } From e991c590559784170dc4e2f32f512eb64e386464 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Wed, 1 May 2024 00:15:14 +0200 Subject: [PATCH 23/28] Minor cleanup. --- src/net/server/middleware/stream.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/net/server/middleware/stream.rs b/src/net/server/middleware/stream.rs index 04dd62a6b..99c74f4d7 100644 --- a/src/net/server/middleware/stream.rs +++ b/src/net/server/middleware/stream.rs @@ -1,3 +1,4 @@ +use core::future::Future; use core::ops::DerefMut; use core::task::{ready, Context, Poll}; @@ -6,10 +7,9 @@ use std::pin::Pin; use futures::prelude::future::FutureExt; use futures::stream::{Stream, StreamExt}; use octseq::Octets; +use tracing::trace; use crate::net::server::message::Request; -use core::future::Future; -use tracing::trace; //------------ MiddlewareStream ---------------------------------------------- @@ -34,7 +34,7 @@ pub enum MiddlewareStream< /// without modification. IdentityStream(IdentityStream), - /// Either a single response has been created without invoking the innter + /// Either a single response has been created without invoking the inner /// service, or the inner service response will be post-processed by this /// service. Map(MapStream), @@ -164,13 +164,13 @@ where ) -> Poll> { match &mut self.state { PostprocessingStreamState::Pending(svc_call_fut) => { - let stream = futures::ready!(svc_call_fut.poll_unpin(cx)); + let stream = ready!(svc_call_fut.poll_unpin(cx)); trace!("Stream has become available"); self.state = PostprocessingStreamState::Streaming(stream); self.poll_next(cx) } PostprocessingStreamState::Streaming(stream) => { - let stream_item = futures::ready!(stream.poll_next_unpin(cx)); + let stream_item = ready!(stream.poll_next_unpin(cx)); trace!("Stream item retrieved, mapping to downstream type"); let request = self.request.clone(); let metadata = self.metadata.clone(); From 7b765ddcc5704a792ef77e658b553360e682db79 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Wed, 12 Jun 2024 11:44:55 +0200 Subject: [PATCH 24/28] Compilation fix. --- tests/net-server.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/net-server.rs b/tests/net-server.rs index 53375cb59..4212879d9 100644 --- a/tests/net-server.rs +++ b/tests/net-server.rs @@ -42,6 +42,7 @@ use domain::stelline::parse_stelline; use domain::stelline::parse_stelline::parse_file; use domain::stelline::parse_stelline::Config; use domain::stelline::parse_stelline::Matches; +use domain::utils::base16; //----------- Tests ---------------------------------------------------------- @@ -62,6 +63,7 @@ async fn server_tests(#[files("test-data/server/*.rpl")] rpl_file: PathBuf) { // Initialize tracing based logging. Override with env var RUST_LOG, e.g. // RUST_LOG=trace. DEBUG level will show the .rpl file name, Stelline step // numbers and types as they are being executed. + tracing_subscriber::fmt() .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) .with_thread_ids(true) @@ -128,7 +130,7 @@ async fn server_tests(#[files("test-data/server/*.rpl")] rpl_file: PathBuf) { #[cfg(feature = "siphasher")] let secret = server_config.cookies.secret.unwrap(); - let secret = hex::decode(secret).unwrap(); + let secret = base16::decode_vec(secret).unwrap(); let secret = <[u8; 16]>::try_from(secret).unwrap(); let svc = CookiesMiddlewareSvc::new(svc, secret) .with_denied_ips(server_config.cookies.ip_deny_list.clone()); From 16f369bf168a9cf9f109240644388b5a8446d8e6 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 23 Jul 2024 15:38:45 +0200 Subject: [PATCH 25/28] RustDoc fixes and additions, minor code tweaks and removal of commented out code. --- examples/serve-zone.rs | 23 ------ examples/server-transports.rs | 2 +- src/net/server/connection.rs | 6 +- src/net/server/middleware/cookies.rs | 2 +- src/net/server/middleware/mod.rs | 27 ++++++++ src/net/server/mod.rs | 100 ++++++++++----------------- src/net/server/service.rs | 48 ++++++------- src/net/server/tests/unit.rs | 1 - 8 files changed, 91 insertions(+), 118 deletions(-) diff --git a/examples/serve-zone.rs b/examples/serve-zone.rs index f5478eeff..356b64998 100644 --- a/examples/serve-zone.rs +++ b/examples/serve-zone.rs @@ -53,8 +53,6 @@ async fn main() { // Populate a zone tree with test data let zone_bytes = include_bytes!("../test-data/zonefiles/nsd-example.txt"); let mut zone_bytes = BufReader::new(&zone_bytes[..]); - // let zone_bytes = std::fs::File::open("/etc/nsd/zones/de-zone").unwrap(); - // let mut zone_bytes = BufReader::new(zone_bytes); // We're reading from static data so this cannot fail due to I/O error. // Don't handle errors that shouldn't happen, keep the example focused @@ -62,19 +60,6 @@ async fn main() { let reader = inplace::Zonefile::load(&mut zone_bytes).unwrap(); let zone = Zone::try_from(reader).unwrap(); - // TODO: Make changes to a zone to create a diff for IXFR use. - // let mut writer = zone.write().await; - // { - // let node = writer.open(true).await.unwrap(); - // let mut new_ns = Rrset::new(Rtype::NS, Ttl::from_secs(60)); - // let ns_rec = domain::rdata::Ns::new( - // Dname::from_str("write-test.example.com").unwrap(), - // ); - // new_ns.push_data(ns_rec.into()); - // node.update_rrset(SharedRrset::new(new_ns)).await.unwrap(); - // } - // let diff = writer.commit().await.unwrap(); - let mut zones = ZoneTree::new(); zones.insert_zone(zone.clone()).unwrap(); let zones = Arc::new(zones); @@ -82,14 +67,6 @@ async fn main() { let addr = "127.0.0.1:8053"; let svc = service_fn(my_service, zones); - // TODO: Insert XFR middleware to automagically handle AXFR and IXFR - // requests. - // let mut svc = XfrMiddlewareSvc::, _>::new(svc); - // svc.add_zone(zone.clone()); - // if let Some(diff) = diff { - // svc.add_diff(&zone, diff); - // } - #[cfg(feature = "siphasher")] let svc = CookiesMiddlewareSvc::, _>::with_random_secret(svc); let svc = EdnsMiddlewareSvc::, _>::new(svc); diff --git a/examples/server-transports.rs b/examples/server-transports.rs index 9eade99c7..59c48d3fe 100644 --- a/examples/server-transports.rs +++ b/examples/server-transports.rs @@ -275,7 +275,7 @@ fn query( eprintln!("Setting idle timeout to {idle_timeout:?}"); let builder = mk_builder_for_target(); - let answer = mk_answer(&request, builder).unwrap(); + let answer = mk_answer(&request, builder)?; Ok(CallResult::new(answer).with_feedback(cmd)) } diff --git a/src/net/server/connection.rs b/src/net/server/connection.rs index 3b901ded2..6ae98dfb5 100644 --- a/src/net/server/connection.rs +++ b/src/net/server/connection.rs @@ -2,10 +2,12 @@ use core::ops::{ControlFlow, Deref}; use core::time::Duration; +use std::fmt::Display; use std::io; use std::net::SocketAddr; use std::sync::Arc; +use arc_swap::ArcSwap; use futures::StreamExt; use octseq::Octets; use tokio::io::{ @@ -18,6 +20,7 @@ use tokio::time::{sleep_until, timeout}; use tracing::Level; use tracing::{debug, enabled, error, trace, warn}; +use crate::base::message_builder::AdditionalBuilder; use crate::base::wire::Composer; use crate::base::{Message, StreamTarget}; use crate::net::server::buf::BufSource; @@ -30,9 +33,6 @@ use crate::utils::config::DefMinMax; use super::message::{NonUdpTransportContext, TransportSpecificContext}; use super::stream::Config as ServerConfig; use super::ServerCommand; -use crate::base::message_builder::AdditionalBuilder; -use arc_swap::ArcSwap; -use std::fmt::Display; /// Limit on the amount of time to allow between client requests. /// diff --git a/src/net/server/middleware/cookies.rs b/src/net/server/middleware/cookies.rs index 1eeaa699a..f7dbcb623 100644 --- a/src/net/server/middleware/cookies.rs +++ b/src/net/server/middleware/cookies.rs @@ -1,4 +1,4 @@ -//! DNS Cookies related message processing. +//! RFC 7873 DNS Cookies related message processing. use core::future::{ready, Ready}; use core::marker::PhantomData; use core::ops::ControlFlow; diff --git a/src/net/server/middleware/mod.rs b/src/net/server/middleware/mod.rs index 1a1434ed8..3d367ca0c 100644 --- a/src/net/server/middleware/mod.rs +++ b/src/net/server/middleware/mod.rs @@ -1,3 +1,30 @@ +//! Request pre-processing and response post-processing middleware. +//! +//! Middleware sits between the server and the application [`Service`], +//! pre-processing requests and post-processing responses in order to +//! filter/reject/modify them according to policy and standards. +//! +//! Middleware is implemented in terms of the [`Service`] trait, just like +//! your application service, but also takes a [`Service`] instance as an +//! argument. This is intended to enable middleware to be composed in layers +//! one atop another, each layer receiving and pre-processing requests from +//! the layer beneath, passing them on to the layer above and then +//! post-processing the resulting responses and propagating them back down +//! through the layers to the server. +//! +//! Currently the following middleware are available: +//! +//! - [`MandatoryMiddlewareSvc`]: Core DNS RFC standards based message +//! processing for MUST requirements. +//! - [`EdnsMiddlewareSvc`]: RFC 6891 and related EDNS message processing. +//! - [`CookiesMiddlewareSvc`]: RFC 7873 DNS Cookies related message +//! processing. +//! +//! [`MandatoryMiddlewareSvc`]: mandatory::MandatoryMiddlewareSvc +//! [`EdnsMiddlewareSvc`]: edns::EdnsMiddlewareSvc +//! [`CookiesMiddlewareSvc`]: cookies::CookiesMiddlewareSvc +//! [`Service`]: crate::net::server::service::Service + #[cfg(feature = "siphasher")] pub mod cookies; pub mod edns; diff --git a/src/net/server/mod.rs b/src/net/server/mod.rs index 27fb886df..4a9d3fcc3 100644 --- a/src/net/server/mod.rs +++ b/src/net/server/mod.rs @@ -8,8 +8,8 @@ //! //! This module provides skeleton asynchronous server implementations based on //! the [Tokio](https://tokio.rs/) async runtime. In combination with an -//! appropriate network source, optional middleware chain and your own -//! [`Service`] implementation they can be used to run a standards compliant +//! appropriate network source, optional middleware services and your own +//! [`Service`] implementation, they can be used to run a standards compliant //! DNS server that answers requests based on the application logic you //! specify. //! @@ -19,30 +19,32 @@ //! requests and outgoing responses: //! //! ```text -//! --> network source - reads bytes from the client -//! --> server - deserializes requests -//! --> (optional) middleware chain - pre-processes requests -//! --> service - processes requests & -//! <-- generates responses -//! <-- (optional) middleware chain - post-processes responses -//! <-- server - serializes responses -//! <-- network source - writes bytes to the client +//! --> network source - reads bytes from the client +//! --> server - deserializes requests +//! --> (optional) middleware services - pre-processes requests +//! --> service - processes requests & +//! <-- generates responses +//! <-- (optional) middleware services - post-processes responses +//! <-- server - serializes responses +//! <-- network source - writes bytes to the client //! ```` //! //! # Getting started //! //! Servers are implemented by combining a server transport (see [dgram] and -//! [stream]), [`BufSource`], (optional) MiddlewareChain and [`Service`] -//! together. +//! [stream]), [`BufSource`] and [`Service`] together. Middleware [`Service`] +//! impls take an upstream [`Service`] instance as input during construction +//! allowing them to be layered on top of one another, with your own +//! application [`Service`] impl at the peak. //! //! Whether using [`DgramServer`] or [`StreamServer`] the required steps are //! the same. //! //! - Create an appropriate network source (more on this below). //! - Construct a server transport with `new()` passing in the network -//! source as an argument. -//! - Tune the server behaviour via builder functions such as -//! `with_config()`. +//! source and service instance as arguments. +//! - (optional) Tune the server behaviour via builder functions such as +//! `with_config()`. //! - `run()` the server. //! - `shutdown()` the server, explicitly or on [`drop`]. //! @@ -59,11 +61,8 @@ //! //! Modern DNS servers increasingly need to support stream based //! connection-oriented network transport protocols for additional response -//! capacity and connection security. -//! -//! This module provides support for both datagram and stream based network -//! transport protocols via the [`DgramServer`] and [`StreamServer`] types -//! respectively. +//! capacity and connection security. This module provides support for both +//! via the [`DgramServer`] and [`StreamServer`] types respectively. //! //! ## Datagram (e.g. UDP) servers //! @@ -92,24 +91,19 @@ //! specific layer of a server nor does it constitute part of the core //! application logic. //! -//! With Middleware, mandatory functionality and logic required by all -//! standards compliant DNS servers can be incorporated into your server by -//! layering your service on top of the [`MandatoryMiddlewareSvc`]. -//! -//! You can also opt to incorporate additional behaviours into your DNS server -//! from a selection of [pre-supplied middleware]. -//! -//! And if the existing middleware services don't meet your needs, maybe you -//! have specific access control or rate limiting requirements for example, -//! you can implement your own middleware service to add your own pre- and -//! post- processing stages into your DNS server. +//! Mandatory functionality and logic required by all standards compliant DNS +//! servers can be incorporated into your server by layering your service on +//! top of [`MandatoryMiddlewareSvc`]. Additional layers of behaviour can be +//! optionally added from a selection of [pre-supplied middleware] or +//! middleware that you create yourself. //! //! ## Application logic //! //! With the basic work of handling DNS requests and responses taken care of, //! the actual application logic that differentiates your DNS server from //! other DNS servers is left for you to define by implementing the -//! [`Service`] trait. +//! [`Service`] trait yourself and passing an instance of that service to the +//! server or middleware service as input. //! //! # Advanced //! @@ -128,43 +122,23 @@ //! //! ## Performance //! -//!
TODO: This section is outdated!
-//! -//! Both [`DgramServer`] and [`StreamServer`] use `CommonMessageFlow` to -//! pre-process the request, invoke [`Service::call`], and post-process the -//! response. -//! -//! - Pre-processing and [`Service::call`] invocation are done from the -//! Tokio task handling the request. For [`DgramServer`] this is the main -//! task that receives incoming messages. For [`StreamServer`] this is a -//! dedicated task per accepted connection. -//! - Post-processing is done in a new task request within which each future -//! resulting from invoking [`Service::call`] is awaited and the resulting -//! response is post-processed. -//! -//! The initial work done by [`Service::call`] should therefore complete as -//! quickly as possible, delegating as much of the work as it can to the -//! future(s) it returns. Until then it blocks the server from receiving new -//! messages, or in the case of [`StreamServer`], new messages for the -//! connection on which the current message was received. +//! Calls into the service layer from the servers are asynchronous and thus +//! managed by the Tokio async runtime. As with any Tokio application, long +//! running tasks should be spawned onto a separate threadpool, e.g. via +//! [`tokio::task::spawn_blocking()`] to avoid blocking the Tokio async +//! runtime. //! //! ## Clone, Arc, and shared state //! //! Both [`DgramServer`] and [`StreamServer`] take ownership of the //! [`Service`] impl passed to them. //! -//! While this may work for some scenarios, real DNS server applications will -//! likely need to accept client requests over multiple transports, will -//! require multiple instances of [`DgramServer`] and [`StreamServer`], and -//! the [`Service`] impl will likely need to have its own state. -//! -//! In these more complex scenarios it becomes more important to understand -//! how the servers work with the [`Service`] impl and the [`Clone`] and -//! [`Arc`] traits. -//! -//! [`DgramServer`] uses a single copy of the [`Service`] impl that it -//! receives but [`StreamServer`] requires that [`Service`] be [`Clone`] -//! because it clones it for each new connection that it accepts. +//! For each request received a new Tokio task is spawned to parse the request +//! bytes, pass it to the first service and process the response(s). +//! [`Service`] impls are therefore required to implement the [`Clone`] trait, +//! either directly or indirectly by for example wrapping the service instance +//! in an [`Arc`], so that [`Service::call`] can be invoked inside the task +//! handling the request. //! //! There are various approaches you can take to manage the sharing of state //! between server instances and processing tasks, for example: diff --git a/src/net/server/service.rs b/src/net/server/service.rs index 812dd642a..b4034e52d 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -24,17 +24,13 @@ use futures::stream::once; /// The type of item that `Service` implementations stream as output. pub type ServiceResult = Result, ServiceError>; -/// [`Service`]s are responsible for determining how to respond to valid DNS +/// `Service`s are responsible for determining how to respond to DNS /// requests. /// -/// A request is "valid" if it passed successfully through the underlying -/// server (e.g. [`DgramServer`] or [`StreamServer`]) and middleware -/// stages. -/// /// For an overview of how services fit into the total flow of request and -/// response handling see the [net::server module documentation]. +/// response handling see the [`net::server`] module documentation. /// -/// Each [`Service`] implementation defines a [`call`] function which takes a +/// Each `Service` implementation defines a [`call`] function which takes a /// [`Request`] DNS request as input and returns a future that yields a stream /// of one or more items each of which is either a [`CallResult`] or /// [`ServiceError`]. @@ -44,10 +40,10 @@ pub type ServiceResult = Result, ServiceError>; /// /// # Usage /// -/// There are three ways to implement the [`Service`] trait: +/// There are three ways to implement the `Service` trait: /// -/// 1. Implement the [`Service`] trait on a struct. -/// 2. Define a function compatible with the [`Service`] trait. +/// 1. Implement the `Service` trait on a struct. +/// 2. Define a function compatible with the `Service` trait. /// 3. Define a function compatible with [`service_fn`]. /// ///
@@ -59,7 +55,7 @@ pub type ServiceResult = Result, ServiceError>; /// ///
/// -/// # Implementing the [`Service`] trait on a `struct` +/// # Implementing the `Service` trait on a `struct` /// /// ``` /// use core::future::ready; @@ -110,7 +106,7 @@ pub type ServiceResult = Result, ServiceError>; /// } /// ``` /// -/// # Define a function compatible with the [`Service`] trait +/// # Define a function compatible with the `Service` trait /// /// ``` /// use core::fmt::Debug; @@ -174,12 +170,12 @@ pub type ServiceResult = Result, ServiceError>; /// /// # Define a function compatible with [`service_fn`] /// -/// See [`service_fn`] for an example of how to use it to create a [`Service`] -/// impl from a funciton. +/// See [`service_fn`] for an example of how to use it to create a `Service` +/// impl from a function. /// /// [`DgramServer`]: crate::net::server::dgram::DgramServer /// [`StreamServer`]: crate::net::server::stream::StreamServer -/// [net::server module documentation]: crate::net::server +/// [`net::server`]: crate::net::server /// [`call`]: Self::call() /// [`service_fn`]: crate::net::server::util::service_fn() pub trait Service + Send + Sync + Unpin = Vec> @@ -200,7 +196,7 @@ pub trait Service + Send + Sync + Unpin = Vec> //--- impl Service for Arc -/// Helper trait impl to treat an [`Arc`] as a [`Service`]. +/// Helper trait impl to treat an [`Arc`] as a `Service`. impl Service for Arc where RequestOctets: Unpin + Send + Sync + AsRef<[u8]>, @@ -217,7 +213,7 @@ where //--- impl Service for functions with matching signature -/// Helper trait impl to treat a function as a [`Service`]. +/// Helper trait impl to treat a function as a `Service`. impl Service for F where RequestOctets: AsRef<[u8]> + Send + Sync + Unpin, @@ -236,7 +232,7 @@ where //------------ ServiceError -------------------------------------------------- -/// An error reported by a [`Service`]. +/// An error reported by a `Service`. #[derive(Debug)] pub enum ServiceError { /// The service was unable to parse the request. @@ -295,21 +291,24 @@ impl From for ServiceError { //------------ ServiceFeedback ----------------------------------------------- -/// Feedback from a [`Service`] to a server asking it to do something. +/// Feedback from a `Service` to a server asking it to do something. #[derive(Copy, Clone, Debug)] pub enum ServiceFeedback { /// Ask the server to alter its configuration. For connection-oriented /// servers the changes will only apply to the current connection. Reconfigure { - /// If `Some`, the new idle timeout the [`Service`] would like the + /// If `Some`, the new idle timeout the `Service` would like the /// server to use. idle_timeout: Option, }, - /// Ensure that messages from this stream are all enqueued, don't drop - /// messages if the outgoing queue is full. + /// Ask the server to wait much longer for responses than it usually would + /// in order to ensure that an entire set of related response messages are + /// all sent back to the caller rather than being dropped if the outgoing + /// queue is full. BeginTransaction, + /// Signal to the server that the transaction that we began has ended. EndTransaction, } @@ -330,14 +329,11 @@ pub struct CallResult { /// Optional response to send back to the client. response: Option>>, - /// Optional feedback from the [`Service`] to the server. + /// Optional feedback from the `Service` to the server. feedback: Option, } impl CallResult -// where -// Target: OctetsBuilder + AsRef<[u8]> + AsMut<[u8]>, -// Target::AppendError: Into, { /// Construct a [`CallResult`] from a DNS response message. #[must_use] diff --git a/src/net/server/tests/unit.rs b/src/net/server/tests/unit.rs index 81c854faa..7d50ff628 100644 --- a/src/net/server/tests/unit.rs +++ b/src/net/server/tests/unit.rs @@ -402,7 +402,6 @@ async fn service_test() { let buf = MockBufSource; let my_service = Arc::new(MandatoryMiddlewareSvc::new(MyService::new())); - // let my_service = Arc::new(MyService::new()); let srv = Arc::new(StreamServer::new(listener, buf, my_service.clone())); From def12fdce6eff0a01d57a1ac7b7d1057aba2ffa4 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 23 Jul 2024 16:59:11 +0200 Subject: [PATCH 26/28] RustDoc fixes. --- src/net/server/util.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/src/net/server/util.rs b/src/net/server/util.rs index 051cf5157..d9d5ec858 100644 --- a/src/net/server/util.rs +++ b/src/net/server/util.rs @@ -35,7 +35,7 @@ where ) } -//------------ streaming_service_fn() ---------------------------------------- +//------------ service_fn() -------------------------------------------------- /// Helper to simplify making a [`Service`] impl. /// @@ -43,17 +43,9 @@ where /// those of its associated types, but this makes implementing it for simple /// cases quite verbose. /// -/// `streaming_service_fn()` enables you to write a slightly simpler function -/// definition that implements the [`Service`] trait than implementing -/// [`Service`] directly. -/// -/// The provided function must produce a future that results in a stream of -/// futures. The envisaged use case for producing a stream of results in the -/// context of DNS is zone transfers. If you need to implement zone transfer -/// or other streaming support yourself then you should implement [`Service`] -/// directly or via `streaming_service_fn`. -/// -/// Most users should probably use `service_fn` instead. +/// `service_fn()` enables you to write a slightly simpler function definition +/// that implements the [`Service`] trait than implementing [`Service`] +/// directly. /// /// # Example /// From 0513f0b336771bc6f85aa4de09e7a831a289390c Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:02:52 +0200 Subject: [PATCH 27/28] Remove RFC 9567 example. as it has nothing to do with the other changes in this branch. --- Cargo.toml | 4 - examples/serve-rfc9567-agent.rs | 149 -------------------------------- 2 files changed, 153 deletions(-) delete mode 100644 examples/serve-rfc9567-agent.rs diff --git a/Cargo.toml b/Cargo.toml index e2e295133..995864729 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -125,10 +125,6 @@ required-features = ["zonefile", "unstable-zonetree"] name = "serve-zone" required-features = ["zonefile", "net", "unstable-server-transport", "unstable-zonetree"] -[[example]] -name = "serve-rfc9567-agent" -required-features = ["net", "unstable-server-transport", "siphasher"] - # This example is commented out because it is difficult, if not impossible, # when including the sqlx dependency, to make the dependency tree compatible # with both `cargo +nightly update -Z minimal versions` and the crate minimum diff --git a/examples/serve-rfc9567-agent.rs b/examples/serve-rfc9567-agent.rs deleted file mode 100644 index 86fb406ef..000000000 --- a/examples/serve-rfc9567-agent.rs +++ /dev/null @@ -1,149 +0,0 @@ -#![cfg(feature = "siphasher")] -use core::future::pending; - -use std::str::FromStr; -use std::sync::Arc; - -use domain::rdata::rfc1035::TxtBuilder; -use tokio::net::{TcpSocket, UdpSocket}; -use tracing_subscriber::EnvFilter; - -use domain::base::iana::{Class, Rcode}; -use domain::base::name::{Label, ToLabelIter}; -use domain::base::{CharStr, NameBuilder, Ttl}; -use domain::net::server::buf::VecBufSource; -use domain::net::server::dgram::DgramServer; -use domain::net::server::message::Request; -use domain::net::server::middleware::cookies::CookiesMiddlewareSvc; -use domain::net::server::middleware::edns::EdnsMiddlewareSvc; -use domain::net::server::middleware::mandatory::MandatoryMiddlewareSvc; -use domain::net::server::service::{CallResult, ServiceResult}; -use domain::net::server::stream::StreamServer; -use domain::net::server::util::{mk_builder_for_target, service_fn}; - -//----------- my_service() --------------------------------------------------- - -fn my_service( - request: Request>, - _metadata: (), -) -> ServiceResult> { - let mut out_answer = None; - if let Ok(question) = request.message().sole_question() { - // We're expecting an RFC 9567 compatible query, i.e. a QNAME of the - // form: - // _er...._er. - // This has at least 6 labels. - // See: https://www.rfc-editor.org/rfc/rfc9567#name-constructing-the-report-que - let qname = question.qname(); - let num_labels = qname.label_count(); - if num_labels >= 6 { - let mut iter = qname.iter_labels(); - let _er = iter.next().unwrap(); - let rep_qtype = iter.next().unwrap(); - let mut rep_qname = NameBuilder::new_vec(); - let mut second_last_label = Option::<&Label>::None; - let mut last_label = None; - loop { - let label = iter.next().unwrap(); - if let Some(label) = second_last_label { - rep_qname.append_label(label.as_slice()).unwrap(); - } - if label == "_er" { - break; - } else { - second_last_label = last_label; - last_label = Some(label); - } - } - let rep_qname = rep_qname.finish(); - let edns_err_code = last_label.unwrap(); - - // Invoke local program to handle the error report - // TODO - eprintln!("Received error report:"); - eprintln!("QNAME: {rep_qname}"); - eprintln!("QTYPE: {rep_qtype}"); - eprintln!("EDNS error: {edns_err_code}"); - - // https://www.rfc-editor.org/rfc/rfc9567#section-6.3-1 - // "It is RECOMMENDED that the authoritative server for the agent - // domain reply with a positive response (i.e., not with NODATA or - // NXDOMAIN) containing a TXT record." - let builder = mk_builder_for_target(); - let mut answer = builder - .start_answer(request.message(), Rcode::NOERROR) - .unwrap(); - let mut txt_builder = TxtBuilder::>::new(); - let txt = { - let cs = - CharStr::>::from_str("Report received").unwrap(); - txt_builder.append_charstr(&cs).unwrap(); - txt_builder.finish().unwrap() - }; - answer - .push((qname, Class::IN, Ttl::from_days(1), txt)) - .unwrap(); - out_answer = Some(answer); - } - } - - if out_answer.is_none() { - let builder = mk_builder_for_target(); - out_answer = Some( - builder - .start_answer(request.message(), Rcode::REFUSED) - .unwrap(), - ); - } - - let additional = out_answer.unwrap().additional(); - Ok(CallResult::new(additional)) -} - -//----------- main() --------------------------------------------------------- - -#[tokio::main(flavor = "multi_thread")] -async fn main() { - // ----------------------------------------------------------------------- - // Setup logging. You can override the log level by setting environment - // variable RUST_LOG, e.g. RUST_LOG=trace. - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .with_thread_ids(true) - .without_time() - .try_init() - .ok(); - - // ----------------------------------------------------------------------- - // Create a service with accompanying middleware chain to answer incoming - // requests. - let svc = service_fn(my_service, ()); - // https://www.rfc-editor.org/rfc/rfc9567#section-6.3-2 "The monitoring - // agent SHOULD respond to queries received over UDP that have no DNS - // Cookie set with a response that has the truncation bit (TC bit) set to - // challenge the resolver to requery over TCP." - let svc = CookiesMiddlewareSvc::, _>::with_random_secret(svc); - let svc = EdnsMiddlewareSvc::, _>::new(svc); - let svc = MandatoryMiddlewareSvc::, _>::new(svc); - let svc = Arc::new(svc); - - // ----------------------------------------------------------------------- - // Run a DNS server on UDP port 8053 on 127.0.0.1 using the my_service - // service defined above and accompanying middleware. - let udpsocket = UdpSocket::bind("127.0.0.1:8053").await.unwrap(); - let buf = Arc::new(VecBufSource); - let srv = DgramServer::new(udpsocket, buf.clone(), svc.clone()); - tokio::spawn(async move { srv.run().await }); - - // ----------------------------------------------------------------------- - // Run a DNS server on TCP port 8053 on 127.0.0.1 using the same service. - let v4socket = TcpSocket::new_v4().unwrap(); - v4socket.set_reuseaddr(true).unwrap(); - v4socket.bind("127.0.0.1:8053".parse().unwrap()).unwrap(); - let v4listener = v4socket.listen(1024).unwrap(); - let buf = Arc::new(VecBufSource); - let srv = StreamServer::new(v4listener, buf.clone(), svc); - tokio::spawn(async move { srv.run().await }); - - pending().await -} From 47a3a15382e4517bf521bce2b5c0e9ce905b1754 Mon Sep 17 00:00:00 2001 From: Ximon Eighteen <3304436+ximon18@users.noreply.github.com> Date: Tue, 23 Jul 2024 17:03:16 +0200 Subject: [PATCH 28/28] cargo fmt. --- src/net/server/middleware/mod.rs | 4 ++-- src/net/server/service.rs | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/net/server/middleware/mod.rs b/src/net/server/middleware/mod.rs index 3d367ca0c..b364a83ba 100644 --- a/src/net/server/middleware/mod.rs +++ b/src/net/server/middleware/mod.rs @@ -11,9 +11,9 @@ //! the layer beneath, passing them on to the layer above and then //! post-processing the resulting responses and propagating them back down //! through the layers to the server. -//! +//! //! Currently the following middleware are available: -//! +//! //! - [`MandatoryMiddlewareSvc`]: Core DNS RFC standards based message //! processing for MUST requirements. //! - [`EdnsMiddlewareSvc`]: RFC 6891 and related EDNS message processing. diff --git a/src/net/server/service.rs b/src/net/server/service.rs index b4034e52d..c78e26051 100644 --- a/src/net/server/service.rs +++ b/src/net/server/service.rs @@ -333,8 +333,7 @@ pub struct CallResult { feedback: Option, } -impl CallResult -{ +impl CallResult { /// Construct a [`CallResult`] from a DNS response message. #[must_use] pub fn new(response: AdditionalBuilder>) -> Self {