Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Require single class zonefiles by default, and give context if possible on parsing errors. #477

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/stelline/parse_stelline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ fn parse_section<Lines: Iterator<Item = Result<String, std::io::Error>>>(
origin = new_origin.to_string();
}
} else {
let mut zonefile = Zonefile::new();
let mut zonefile = Zonefile::new().allow_invalid();
zonefile.extend_from_slice(
format!("$ORIGIN {origin}\n").as_bytes(),
);
Expand Down
146 changes: 119 additions & 27 deletions src/zonefile/inplace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,14 @@ pub type ScannedString = Str<Bytes>;
/// into the memory buffer. The function [`load`][Self::load] can be used to
/// create a value directly from a reader.
///
/// Once data has been added, you can simply iterate over the value to
/// get entries. The [`next_entry`][Self::next_entry] method provides an
/// Once data has been added, you can simply iterate over the value to get
/// entries. The [`next_entry`][Self::next_entry] method provides an
/// alternative with a more question mark friendly signature.
///
/// By default RFC 1035 validity checks are enabled. At present only the first
/// check is implemented: "1. All RRs in the zonefile should have the same
/// class". To disable strict validation call [`allow_invalid()`] prior to
/// calling [`load()`].
#[derive(Clone, Debug)]
pub struct Zonefile {
/// This is where we keep the data of the next entry.
Expand All @@ -73,7 +78,11 @@ pub struct Zonefile {
last_ttl: Ttl,

/// The last class.
last_class: Class,
last_class: Option<Class>,

/// Whether the loaded zonefile should be required to pass RFC 1035
/// validity checks.
require_valid: bool,
}

impl Zonefile {
Expand All @@ -89,14 +98,21 @@ impl Zonefile {
)))
}

/// Disables RFC 1035 section 5.2 zonefile validity checks.
pub fn allow_invalid(mut self) -> Self {
self.require_valid = false;
self
}

/// Creates a new value using the given buffer.
fn with_buf(buf: SourceBuf) -> Self {
Zonefile {
buf,
origin: None,
last_owner: None,
last_ttl: Ttl::from_secs(3600),
last_class: Class::IN,
last_class: None,
require_valid: true,
}
}

Expand Down Expand Up @@ -342,12 +358,27 @@ impl<'a> EntryScanner<'a> {
self.zonefile.last_owner = Some(owner.clone());
}

let class = match class {
Some(class) => {
self.zonefile.last_class = class;
let class = match (class, self.zonefile.last_class) {
(Some(class), Some(last_class)) => {
if self.zonefile.require_valid && class != last_class {
return Err(EntryError::different_class(
last_class, class,
));
}
class
}

(Some(class), None) => {
self.zonefile.last_class = Some(class);
class
}
None => self.zonefile.last_class,

(None, Some(last_class)) => last_class,

(None, None) => {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is not correct. There is no default class according to RFC 1035. Do other parsers accept leaving out the class entirely?

self.zonefile.last_class = Some(Class::IN);
Class::IN
}
};

let ttl = match ttl {
Expand Down Expand Up @@ -472,7 +503,7 @@ impl<'a> EntryScanner<'a> {
self.zonefile.buf.require_line_feed()?;
Ok(ScannedEntry::Ttl(Ttl::from_secs(ttl)))
} else {
Err(EntryError::unknown_control())
Err(EntryError::unknown_control(ctrl))
}
}
}
Expand Down Expand Up @@ -1438,75 +1469,136 @@ enum ItemCat {

/// An error returned by the entry scanner.
#[derive(Clone, Debug)]
pub struct EntryError(&'static str);
pub struct EntryError {
msg: &'static str,

#[cfg(feature = "std")]
context: Option<std::string::String>,
}

impl EntryError {
fn bad_symbol(_err: SymbolOctetsError) -> Self {
EntryError("bad symbol")
EntryError {
msg: "bad symbol",
#[cfg(feature = "std")]
context: Some(format!("{}", _err)),
}
}

fn bad_charstr() -> Self {
EntryError("bad charstr")
EntryError {
msg: "bad charstr",
#[cfg(feature = "std")]
context: None,
}
}

fn bad_name() -> Self {
EntryError("bad name")
EntryError {
msg: "bad name",
#[cfg(feature = "std")]
context: None,
}
}

fn unbalanced_parens() -> Self {
EntryError("unbalanced parens")
EntryError {
msg: "unbalanced parens",
#[cfg(feature = "std")]
context: None,
}
}

fn missing_last_owner() -> Self {
EntryError("missing last owner")
EntryError {
msg: "missing last owner",
#[cfg(feature = "std")]
context: None,
}
}

fn missing_origin() -> Self {
EntryError("missing origin")
EntryError {
msg: "missing origin",
#[cfg(feature = "std")]
context: None,
}
}

fn expected_rtype() -> Self {
EntryError("expected rtype")
EntryError {
msg: "expected rtype",
#[cfg(feature = "std")]
context: None,
}
}

fn unknown_control() -> Self {
EntryError("unknown control")
fn unknown_control(ctrl: Str<Bytes>) -> Self {
EntryError {
msg: "unknown control",
#[cfg(feature = "std")]
context: Some(format!("{}", ctrl)),
}
}

fn different_class(expected_class: Class, found_class: Class) -> Self {
EntryError {
msg: "different class",
#[cfg(feature = "std")]
context: Some(format!("{found_class} != {expected_class}")),
}
}
}

impl ScannerError for EntryError {
fn custom(msg: &'static str) -> Self {
EntryError(msg)
EntryError {
msg,
#[cfg(feature = "std")]
context: None,
}
}

fn end_of_entry() -> Self {
Self("unexpected end of entry")
Self::custom("unexpected end of entry")
}

fn short_buf() -> Self {
Self("short buffer")
Self::custom("short buffer")
}

fn trailing_tokens() -> Self {
Self("trailing tokens")
Self::custom("trailing tokens")
}
}

impl From<SymbolOctetsError> for EntryError {
fn from(_: SymbolOctetsError) -> Self {
EntryError("symbol octets error")
fn from(err: SymbolOctetsError) -> Self {
Self::bad_symbol(err)
}
}

impl From<BadSymbol> for EntryError {
fn from(_: BadSymbol) -> Self {
EntryError("bad symbol")
Self::custom("bad symbol")
}
}

impl fmt::Display for EntryError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(self.0.as_ref())
#[cfg(not(feature = "std"))]
{
f.write_str(self.msg)
}

#[cfg(feature = "std")]
{
if let Some(context) = &self.context {
f.write_fmt(format_args!("{}: {}", self.msg, context))
} else {
f.write_str(self.msg)
}
}
Comment on lines +1589 to +1601
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could be more concise if you write out the message unconditionally:

f.write_str(self.msg)?;
#[cfg(feature = "std")]
if let Some(context) = &self.context {
    write!(f, ": {}", context)?;
}

}
}

Expand Down
Loading