1use std::{
2 borrow::Cow,
3 fmt,
4 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
5};
6
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Deserializer, Serialize, Serializer, de};
9
/// A network address: either an IP address or a host name, each paired with a port.
///
/// Note that variant order is significant for the derived `PartialOrd`/`Ord`:
/// every `Inet` address sorts before every `Name` address.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum NetAddr {
    /// An IP address (v4 or v6) and a port.
    Inet(IpAddr, u16),
    /// A host name, stored as given, and a port.
    Name(Cow<'static, str>, u16),
}
18
19impl NetAddr {
20 pub fn named<S>(name: S, port: u16) -> Self
21 where
22 S: Into<Cow<'static, str>>,
23 {
24 Self::Name(name.into(), port)
25 }
26
27 pub fn port(&self) -> u16 {
29 match self {
30 Self::Inet(_, p) => *p,
31 Self::Name(_, p) => *p,
32 }
33 }
34
35 pub fn set_port(&mut self, p: u16) {
37 match self {
38 Self::Inet(_, o) => *o = p,
39 Self::Name(_, o) => *o = p,
40 }
41 }
42
43 pub fn with_port(mut self, p: u16) -> Self {
44 match self {
45 Self::Inet(ip, _) => self = Self::Inet(ip, p),
46 Self::Name(hn, _) => self = Self::Name(hn, p),
47 }
48 self
49 }
50
51 pub fn with_offset(mut self, o: u16) -> Self {
52 debug_assert!(self.port().checked_add(o).is_some());
53 match self {
54 Self::Inet(ip, p) => self = Self::Inet(ip, p + o),
55 Self::Name(hn, p) => self = Self::Name(hn, p + o),
56 }
57 self
58 }
59
60 pub fn is_ip(&self) -> bool {
61 matches!(self, Self::Inet(..))
62 }
63
64 pub fn is_probably_global(&self) -> bool {
71 match self {
72 Self::Inet(IpAddr::V4(v4), _) => {
73 !(v4.is_loopback()
74 || v4.is_unspecified()
75 || v4.is_private()
76 || v4.is_link_local()
77 || v4.is_broadcast()
78 || v4.is_documentation())
79 },
80 Self::Inet(IpAddr::V6(v6), _) => {
81 !(v6.is_loopback() || v6.is_unspecified() || v6.is_multicast())
82 },
83 Self::Name(host, _) => !host.eq_ignore_ascii_case("localhost"),
84 }
85 }
86}
87
88impl fmt::Display for NetAddr {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 match self {
91 Self::Inet(a, p) => write!(f, "{a}:{p}"),
92 Self::Name(h, p) => write!(f, "{h}:{p}"),
93 }
94 }
95}
96
97impl From<(&str, u16)> for NetAddr {
98 fn from((h, p): (&str, u16)) -> Self {
99 Self::Name(h.to_string().into(), p)
100 }
101}
102
103impl From<(String, u16)> for NetAddr {
104 fn from((h, p): (String, u16)) -> Self {
105 Self::Name(h.into(), p)
106 }
107}
108
109impl From<(IpAddr, u16)> for NetAddr {
110 fn from((ip, p): (IpAddr, u16)) -> Self {
111 Self::Inet(ip, p)
112 }
113}
114
115impl From<(Ipv4Addr, u16)> for NetAddr {
116 fn from((ip, p): (Ipv4Addr, u16)) -> Self {
117 Self::Inet(IpAddr::V4(ip), p)
118 }
119}
120
121impl From<(Ipv6Addr, u16)> for NetAddr {
122 fn from((ip, p): (Ipv6Addr, u16)) -> Self {
123 Self::Inet(IpAddr::V6(ip), p)
124 }
125}
126
127impl From<SocketAddr> for NetAddr {
128 fn from(a: SocketAddr) -> Self {
129 Self::Inet(a.ip(), a.port())
130 }
131}
132
133impl std::str::FromStr for NetAddr {
134 type Err = InvalidNetAddr;
135
136 fn from_str(s: &str) -> Result<Self, Self::Err> {
137 if s.is_empty() {
138 return Err(InvalidNetAddr(()));
139 }
140
141 let parse = |a: &str, p: Option<&str>| {
142 let p: u16 = if let Some(p) = p {
143 p.parse().map_err(|_| InvalidNetAddr(()))?
144 } else {
145 0
146 };
147 let a = if a.starts_with('[') && a.ends_with(']') {
149 &a[1..a.len() - 1]
150 } else {
151 a
152 };
153 IpAddr::from_str(a)
154 .map(|a| Self::Inet(a, p))
155 .or_else(|_| Ok(Self::Name(a.to_string().into(), p)))
156 };
157
158 if s.starts_with('[') {
160 return match s.rfind("]:") {
161 Some(i) => parse(&s[..i + 1], Some(&s[i + 2..])),
162 None => parse(s, None),
163 };
164 }
165
166 match s.rsplit_once(':') {
167 None => parse(s, None),
168 Some((a, p)) => parse(a, Some(p)),
169 }
170 }
171}
172
173impl TryFrom<&str> for NetAddr {
174 type Error = InvalidNetAddr;
175
176 fn try_from(val: &str) -> Result<Self, Self::Error> {
177 val.parse()
178 }
179}
180
/// Error returned when a string cannot be parsed as a [`NetAddr`].
///
/// The `()` field is private, so this error can only be constructed inside
/// this module (presumably to keep it opaque so detail can be added later —
/// confirm with the API owner).
#[derive(Debug, Clone, thiserror::Error)]
#[error("invalid network address")]
pub struct InvalidNetAddr(());
184
#[cfg(feature = "serde")]
/// Serializes the address as its `Display` string, e.g. `"example.com:443"`.
impl Serialize for NetAddr {
    fn serialize<S: Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        // `collect_str` lets the serializer consume the `Display` output
        // directly instead of first allocating an intermediate `String`.
        s.collect_str(self)
    }
}
193
#[cfg(feature = "serde")]
/// Deserializes from a string, accepting the same forms as `str::parse::<NetAddr>`.
/// A parse failure is reported through the deserializer's custom error.
impl<'de> Deserialize<'de> for NetAddr {
    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        String::deserialize(d)?
            .parse::<Self>()
            .map_err(de::Error::custom)
    }
}
202
#[cfg(test)]
mod tests {
    use std::{iter::repeat_with, net::IpAddr};

    use quickcheck::{Arbitrary, Gen, quickcheck};

    use super::NetAddr;

    // Random addresses for property tests: either a host name or an IP
    // address, each with a random port. Host names never contain `[` or `]`,
    // which the parser treats as IP-literal brackets.
    impl Arbitrary for NetAddr {
        fn arbitrary(g: &mut Gen) -> Self {
            // Draw port, then kind, then payload.
            let port = u16::arbitrary(g);
            let want_name = bool::arbitrary(g);
            if !want_name {
                return NetAddr::Inet(IpAddr::arbitrary(g), port);
            }
            let max_len = usize::from(u8::arbitrary(g));
            let host: String = repeat_with(|| char::arbitrary(g))
                .filter(|&ch| ch != '[' && ch != ']')
                .take(max_len)
                .collect();
            NetAddr::Name(host.into(), port)
        }
    }

    quickcheck! {
        fn prop_to_string_parse_identity(a: NetAddr) -> bool {
            let rendered = a.to_string();
            matches!(rendered.parse::<NetAddr>(), Ok(parsed) if parsed == a)
        }
    }

    #[test]
    fn empty_is_invalid() {
        let parsed = "".parse::<NetAddr>();
        assert!(parsed.is_err())
    }

    #[test]
    fn test_is_probably_global() {
        const CASES: &[(&str, bool)] = &[
            // Non-global IPv4.
            ("127.0.0.1:1234", false),
            ("0.0.0.0:1234", false),
            ("10.0.0.1:1234", false),
            ("172.16.5.4:1234", false),
            ("192.168.1.1:1234", false),
            ("169.254.0.1:1234", false),
            ("255.255.255.255:1234", false),
            ("192.0.2.1:1234", false),
            // Non-global IPv6 (unbracketed; the last group is the port).
            ("::1:1234", false),
            (":::1234", false),
            ("ff00::1:1234", false),
            // Local names.
            ("localhost:1234", false),
            ("LOCALHOST:1234", false),
            // Global.
            ("8.8.8.8:1234", true),
            ("1.1.1.1:1234", true),
            ("2606:4700:4700::1111:1234", true),
            ("example.com:1234", true),
            ("node.internal:1234", true),
        ];
        for &(input, expected) in CASES {
            let addr = input
                .parse::<NetAddr>()
                .unwrap_or_else(|_| panic!("parse {input}"));
            assert_eq!(addr.is_probably_global(), expected, "for input {input}");
        }
    }
}