Skip to content

Commit

Permalink
✨ Add symlink subcommand and support for multiple query types (#423)
Browse files Browse the repository at this point in the history
1. Create smylink: `./smartdns symlink ./dig`

2. Query A and AAAA record: `./dig example.com a+aaaa`
  • Loading branch information
mokeyish authored Nov 3, 2024
1 parent de2df7a commit 27eb918
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 28 deletions.
14 changes: 13 additions & 1 deletion src/cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ pub struct Cli {

impl Cli {
pub fn parse() -> Self {
#[cfg(feature = "resolve-cli")]
if ResolveCommand::is_resolve_cli() {
return ResolveCommand::parse().into();
}

match Self::try_parse() {
Ok(cli) => cli,
Err(e) => {
Expand Down Expand Up @@ -96,10 +101,17 @@ pub enum Commands {
command: ServiceCommands,
},

/// Perform DNS resolution. Can be used in place of the standard OS resolution facilities.
/// Perform DNS resolution.
#[cfg(feature = "resolve-cli")]
Resolve(ResolveCommand),

/// Create a symbolic link to the Smart-DNS binary (drop-in replacement for `dig`, `nslookup`, `resolve` etc.)
#[cfg(feature = "resolve-cli")]
Symlink {
/// The path to the symlink to create.
link: std::path::PathBuf,
},

/// Test configuration and exit
Test {
/// Config file
Expand Down
19 changes: 19 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,25 @@ impl Cli {
drop(_guard);
command.execute();
}
#[cfg(all(feature = "resolve-cli", any(unix, windows)))]
Commands::Symlink { link } => {
let original = std::env::current_exe().expect("failed to get current exe path");
if link.exists() {
println!("link already exists");
return;
}

#[cfg(unix)]
let res = std::os::unix::fs::symlink(original, link);

#[cfg(windows)]
let res = std::os::windows::fs::symlink_file(original, link);

match res {
Ok(()) => println!("symlink created"),
Err(err) => println!("failed to create symlink, {}", err),
}
}
#[allow(unreachable_patterns)]
_ => {
unimplemented!()
Expand Down
117 changes: 90 additions & 27 deletions src/resolver.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::collections::HashMap;
use std::path::Path;
use std::{ops::Deref, str::FromStr, time::Duration};

use clap::Parser;
Expand All @@ -7,10 +8,7 @@ use console::{style, StyledObject};

use crate::libdns::proto::{
op::Message,
rr::{
DNSClass as QueryClass, Name as Domain, Record, RecordData, RecordType as QueryType,
RecordType,
},
rr::{DNSClass as QueryClass, Name as Domain, Record, RecordData, RecordType},
xfer::Protocol as DnsOverProtocol,
};

Expand All @@ -30,7 +28,7 @@ impl ResolveCommand {
}
}
let domain = self.domain().clone();
let query_type = self.q_type();
let query_types = self.q_type();

let palette = Colours::pretty();

Expand All @@ -50,17 +48,19 @@ impl ResolveCommand {
DnsClient::builder().build().await
};

let options = LookupOptions {
record_type: query_type,
..Default::default()
};

match dns_client.lookup(domain, options).await {
Ok(res) => {
print(&res, &palette);
}
Err(err) => {
println!("{}", err);
for query_type in query_types {
let options = LookupOptions {
record_type: *query_type,
..Default::default()
};

match dns_client.lookup(domain.clone(), options).await {
Ok(res) => {
print(&res, &palette);
}
Err(err) => {
println!("{}", err);
}
}
}
});
Expand Down Expand Up @@ -100,7 +100,7 @@ pub struct ResolveCommand {

/// is one of (a,any,mx,ns,soa,hinfo,axfr,txt,...)
#[arg(value_name = "q-type", default_value = "a", value_parser = Self::parse_query_type)]
q_type: QueryType,
q_type: QueryTypes,

/// is one of (in,hs,ch,...)
#[arg(value_name = "q-class", default_value = "in", value_parser = Self::parse_query_class)]
Expand All @@ -112,13 +112,26 @@ pub struct ResolveCommand {
}

impl ResolveCommand {
pub fn parse() -> Self {
match Parser::try_parse() {
Ok(cli) => cli,
Err(e) => {
if let Ok(resolve_command) = ResolveCommand::try_parse() {
return resolve_command;
}
e.exit()
}
}
}

pub fn try_parse() -> Result<Self, String> {
use DnsOverProtocol::*;
let mut proto = None;
let mut q_type = None;
let mut q_types = vec![];
let mut q_class = None;
let mut domain = None;
let mut global_server = None;
let mut prev_parsing_qtype = false;

for arg in std::env::args().skip(1) {
if arg == "resolve" {
Expand Down Expand Up @@ -159,11 +172,18 @@ impl ResolveCommand {
continue;
}

if q_type.is_none() {
if let Ok(t) = Self::parse_query_type(arg.as_str()) {
q_type = Some(t);
if q_types.is_empty() {
if let Ok(t) = Self::parse_query_type(&arg) {
q_types = t.0;
prev_parsing_qtype = true;
continue;
}
} else if prev_parsing_qtype {
if let Ok(t) = Self::parse_query_type(&arg) {
q_types.extend(t.0);
continue;
}
prev_parsing_qtype = false;
}

if q_class.is_none() {
Expand All @@ -179,14 +199,17 @@ impl ResolveCommand {
continue;
}
}

return Err(format!("Invalid argument {arg}"));
}

let Some(domain) = domain else {
return Err("domain is required".to_string());
};

let q_type = q_type.unwrap_or(QueryType::A);
if q_types.is_empty() {
q_types.push(RecordType::A);
}
let q_class = q_class.unwrap_or(QueryClass::IN);

Ok(Self {
Expand All @@ -198,11 +221,27 @@ impl ResolveCommand {
h3: matches!(proto, Some(H3)),
global_server,
domain,
q_type,
q_type: QueryTypes(q_types),
q_class,
})
}

pub fn is_resolve_cli() -> bool {
std::env::args()
.next()
.as_deref()
.map(Path::new)
.and_then(|s| s.file_stem())
.and_then(|s| s.to_str())
.map(|s| match s {
"dig" => true,
"nslookup" => true,
"resolve" => true,
_ => false,
})
.unwrap_or_default()
}

pub fn proto(&self) -> Option<DnsOverProtocol> {
use DnsOverProtocol::*;
if self.udp {
Expand Down Expand Up @@ -230,8 +269,8 @@ impl ResolveCommand {
&self.domain
}

pub fn q_type(&self) -> QueryType {
self.q_type
pub fn q_type(&self) -> &[RecordType] {
&self.q_type.0
}

pub fn q_class(&self) -> QueryClass {
Expand All @@ -245,14 +284,38 @@ impl ResolveCommand {
Err(format!("Invalid global server: {}", s))
}
}
fn parse_query_type(s: &str) -> Result<QueryType, String> {
QueryType::from_str(s.to_uppercase().as_str()).map_err(|e| e.to_string())
fn parse_query_type(s: &str) -> Result<QueryTypes, String> {
if s.contains("+") {
let mut types = Vec::new();
let mut last_err = None;
for t in s.split('+') {
match RecordType::from_str(t.to_uppercase().as_str()) {
Ok(t) => types.push(t),
Err(err) => last_err = Some(err),
}
}

if types.is_empty() {
return Err(last_err
.map(|e| e.to_string())
.unwrap_or("Invalid query type".to_string()));
}

Ok(QueryTypes(types))
} else {
RecordType::from_str(s.to_uppercase().as_str())
.map(|q| QueryTypes(vec![q]))
.map_err(|e| e.to_string())
}
}
fn parse_query_class(s: &str) -> Result<QueryClass, String> {
QueryClass::from_str(s.to_uppercase().as_str()).map_err(|e| e.to_string())
}
}

#[derive(Debug, Clone)]
struct QueryTypes(Vec<RecordType>);

fn print(message: &Message, palette: &Colours) {
for r in message.answers() {
print_record(&r, palette);
Expand Down

0 comments on commit 27eb918

Please sign in to comment.