diff --git a/src/networking.rs b/src/networking.rs index 9828dc5..123c1e7 100644 --- a/src/networking.rs +++ b/src/networking.rs @@ -1,9 +1,8 @@ use { - log::error, - rand::{seq::SliceRandom, thread_rng as rng}, - std::net::{IpAddr, Ipv4Addr}, + crate::{files, structs::Args}, + std::{collections::HashSet, net::SocketAddr}, trust_dns_resolver::{ - config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, + config::{NameServerConfig, NameServerConfigGroup, Protocol, ResolverConfig, ResolverOpts}, Resolver, }, }; @@ -19,29 +18,51 @@ pub fn get_records(resolver: &Resolver, domain: &str) -> String { } } -pub fn get_resolver(resolvers_ips: &[Ipv4Addr], opts: &ResolverOpts) -> Resolver { - match Resolver::new( - ResolverConfig::from_parts( - None, - vec![], - NameServerConfigGroup::from_ips_clear( - &[IpAddr::V4( - resolvers_ips - .choose(&mut rng()) - .expect("failed to read ipv4 string") - .to_owned(), - )], - 53, - false, - ), - ), - *opts, - ) { - Ok(resolver) => resolver, +pub fn get_resolver(nameserver_ips: HashSet, opts: ResolverOpts) -> Resolver { + let mut name_servers = NameServerConfigGroup::with_capacity(nameserver_ips.len() * 2); + name_servers.extend(nameserver_ips.into_iter().flat_map(|socket_addr| { + std::iter::once(NameServerConfig { + socket_addr, + protocol: Protocol::Udp, + tls_dns_name: None, + trust_nx_responses: false, + }) + .chain(std::iter::once(NameServerConfig { + socket_addr, + protocol: Protocol::Tcp, + tls_dns_name: None, + trust_nx_responses: false, + })) + })); + Resolver::new(ResolverConfig::from_parts(None, vec![], name_servers), opts).unwrap() +} - Err(e) => { - error!("Failed to create the resolver. Error: {}\n", e); - std::process::exit(1) +pub fn return_socket_address(args: &Args) -> HashSet { + let mut resolver_ips = HashSet::new(); + if args.custom_resolvers { + for r in &files::return_file_targets(&args, args.resolvers.clone()) { + let server = r.to_owned() + ":53"; + let socket_addr = SocketAddr::V4(match server.parse() { + Ok(a) => a, + Err(e) => unreachable!( + "Error parsing the server {}, only IPv4 are allowed. Error: {}", + r, e + ), + }); + resolver_ips.insert(socket_addr); + } + } else { + for r in &args.resolvers { + let server = r.to_owned() + ":53"; + let socket_addr = SocketAddr::V4(match server.parse() { + Ok(a) => a, + Err(e) => unreachable!( + "Error parsing the server {}, only IPv4 are allowed. Error: {}", + r, e + ), + }); + resolver_ips.insert(socket_addr); } } + resolver_ips } diff --git a/src/resolver_engine.rs b/src/resolver_engine.rs index bc933bd..8bedae9 100644 --- a/src/resolver_engine.rs +++ b/src/resolver_engine.rs @@ -15,7 +15,10 @@ use { net::Ipv4Addr, time::Duration, }, - trust_dns_resolver::config::ResolverOpts, + trust_dns_resolver::{ + config::{LookupIpStrategy, ResolverOpts}, + Resolver, + }, }; lazy_static! { @@ -57,7 +60,16 @@ pub fn parallel_resolver_all(args: &mut Args) -> Result<()> { ) } - let data = parallel_resolver_engine(&args, args.targets.clone()); + let opts = ResolverOpts { + timeout: Duration::from_secs(1), + ip_strategy: LookupIpStrategy::Ipv4Only, + num_concurrent_reqs: 1, + ..Default::default() + }; + + let resolver = networking::get_resolver(networking::return_socket_address(args), opts); + + let data = parallel_resolver_engine(&args, args.targets.clone(), resolver); let mut table = Table::new(); table.set_titles(row![ @@ -186,19 +198,17 @@ pub fn parallel_resolver_all(args: &mut Args) -> Result<()> { Ok(()) } -fn parallel_resolver_engine(args: &Args, targets: HashSet) -> HashMap { - let opts = ResolverOpts { - timeout: Duration::from_secs(2), - ..Default::default() - }; - +fn parallel_resolver_engine( + args: &Args, + targets: HashSet, + resolver: Resolver, +) -> HashMap { let resolv_data: HashMap = targets .par_iter() .map(|target| { let fqdn_target = format!("{}.", target); let mut resolv_data = ResolvData::default(); - resolv_data.ip = - networking::get_records(&networking::get_resolver(&RESOLVERS, &opts), &fqdn_target); + resolv_data.ip = networking::get_records(&resolver, &fqdn_target); (target.to_owned(), resolv_data) }) .collect();