1use super::prelude::*;
17
18use proc_macro2::Group;
19use Equality::*;
20
/// Outcome of comparing two operands (tokens, literals, streams) in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Equality {
    /// The operands compared equal.
    Equal,
    /// The operands compared unequal.
    Different,
}
29
30impl Equality {
31 pub fn cmpeq<T: Eq>(a: &T, b: &T) -> Self {
35 if a == b {
36 Equal
37 } else {
38 Different
39 }
40 }
41}
42
/// Early-return helper for the comparison functions in this module.
///
/// * `cmpeq!(lhs, rhs)` compares two values using [`Equality::cmpeq`].
/// * `cmpeq!(outcome)` inspects an already-computed [`Equality`].
///
/// In both forms, a `Different` outcome makes the enclosing function
/// return `Ok(Different)` immediately; `Equal` lets execution continue.
macro_rules! cmpeq {
    { $lhs:expr, $rhs:expr } => {
        cmpeq!(Equality::cmpeq(&$lhs, &$rhs))
    };
    { $outcome:expr } => {
        match $outcome {
            Equality::Equal => {}
            neq => return Ok(neq),
        }
    };
}
63
64pub fn flatten_none_groups(ts: TokenStream) -> TokenStream {
68 fn recurse(out: &mut TokenStream, input: TokenStream) {
69 for tt in input {
70 match tt {
71 TT::Group(g) if g.delimiter() == Delimiter::None => {
72 recurse(out, g.stream());
73 }
74 TT::Group(g) => {
75 let g = group_clone_set_stream(
76 &g,
77 flatten_none_groups(g.stream()),
78 );
79 out.extend([TT::Group(g)]);
80 }
81 _ => out.extend([tt]),
82 }
83 }
84 }
85
86 let mut out = TokenStream::new();
87 recurse(&mut out, ts);
88 out
89}
90
91trait LitComparable {
92 fn lc_compare(
93 a: &Self,
94 b: &Self,
95 cmp_loc: &ErrorLoc<'_>,
96 ) -> syn::Result<Equality>;
97}
98trait LitConvertible {
99 type V: Eq;
100 fn lc_convert(&self, cmp_loc: &ErrorLoc<'_>) -> syn::Result<Self::V>;
101}
102fn str_check_suffix(
103 suffix: &str,
104 span: Span,
105 cmp_loc: &ErrorLoc<'_>,
106) -> syn::Result<()> {
107 if suffix.is_empty() {
108 Ok(())
109 } else {
110 Err([(span, "literal"), *cmp_loc].error(
111 "comparison of string/byte/character literals with suffixes is not supported"
112 ))
113 }
114}
/// Implements `LitConvertible` for a suffix-less literal type `$lit`.
///
/// The literal's `value()` (of type `$val`) is the comparison value. Any
/// suffix is rejected via `str_check_suffix`.
macro_rules! impl_LitComparable_str {
    { $lit:ty, $val:ty } => {
        impl LitConvertible for $lit {
            type V = $val;
            fn lc_convert(&self, cmp_loc: &ErrorLoc<'_>) -> syn::Result<Self::V> {
                str_check_suffix(self.suffix(), self.span(), cmp_loc)
                    .map(|()| self.value())
            }
        }
    };
}
124
125impl_LitComparable_str!(syn::LitStr, String);
126impl_LitComparable_str!(syn::LitByteStr, Vec<u8>);
127impl_LitComparable_str!(syn::LitByte, u8);
128impl_LitComparable_str!(syn::LitChar, char);
129
130impl<T: LitConvertible> LitComparable for T {
131 fn lc_compare(
132 a: &Self,
133 b: &Self,
134 cmp_loc: &ErrorLoc<'_>,
135 ) -> syn::Result<Equality> {
136 Ok(Equality::cmpeq(
137 &a.lc_convert(cmp_loc)?,
139 &b.lc_convert(cmp_loc)?,
140 ))
141 }
142}
143
144impl LitConvertible for syn::LitBool {
145 type V = ();
146 fn lc_convert(&self, _cmp_loc: &ErrorLoc<'_>) -> syn::Result<Self::V> {
147 Err(self.error(
148 "internal error - TokenTree::Literal parsed as syn::Lit::Bool",
149 ))
150 }
151}
152
153impl LitConvertible for syn::LitFloat {
154 type V = String;
155 fn lc_convert(&self, _cmp_loc: &ErrorLoc<'_>) -> syn::Result<Self::V> {
156 Ok(self.token().to_string())
157 }
158}
159
160impl LitComparable for syn::LitInt {
161 fn lc_compare(
162 a: &Self,
163 b: &Self,
164 cmp_loc: &ErrorLoc<'_>,
165 ) -> syn::Result<Equality> {
166 match (
167 a.base10_parse::<u64>(),
168 b.base10_parse::<u64>(),
169 ) {
170 (Ok(a), Ok(b)) => Ok(Equality::cmpeq(&a, &b)),
171 (Err(ae), Err(be)) => Err(
172 [(a.span(), &*format!("left: {}", ae)),
173 (b.span(), &*format!("right: {}", be)),
174 *cmp_loc,
175 ].error(
176 "integer literal comparison with both values >u64 is not supported"
177 )),
178 (Err(_), Ok(_)) | (Ok(_), Err(_)) => Ok(Different),
179 }
180 }
181}
182
183fn lit_cmpeq(
184 a: &TokenTree,
185 b: &TokenTree,
186 cmp_loc: &ErrorLoc<'_>,
187) -> syn::Result<Equality> {
188 let mk_lit = |tt: &TokenTree| -> syn::Result<syn::Lit> {
189 syn::parse2(tt.clone().into())
190 };
191
192 let a = mk_lit(a)?;
193 let b = mk_lit(b)?;
194
195 syn_lit_cmpeq_approx(a, b, cmp_loc)
196}
197
198pub fn syn_lit_cmpeq_approx(
202 a: syn::Lit,
203 b: syn::Lit,
204 cmp_loc: &ErrorLoc<'_>,
205) -> syn::Result<Equality> {
206 macro_rules! match_lits { { $( $V:ident )* } => {
207 let mut error_locs = vec![];
208 for (lit, lr) in [(&a, "left"), (&b, "right")] {
209 match lit {
210 $(
211 syn::Lit::$V(_) => {}
212 )*
213 _ => error_locs.push((lit.span(), lr)),
214 }
215 }
216 if !error_locs.is_empty() {
217 return Err(error_locs.error(
218 "unsupported literal(s) in approx_equal comparison"
219 ));
220 }
221
222 match (&a, &b) {
223 $(
224 (syn::Lit::$V(a), syn::Lit::$V(b))
225 => LitComparable::lc_compare(a, b, cmp_loc),
226 )*
227 _ => Ok(Different),
228 }
229 } }
230
231 match_lits! {
243 Str
244 ByteStr
245 Byte
246 Char
247 Bool
248 Int
249 Float
250 }
251}
252
253fn tt_cmpeq(
254 a: TokenTree,
255 b: TokenTree,
256 cmp_loc: &ErrorLoc<'_>,
257) -> syn::Result<Equality> {
258 let discrim = |tt: &_| match tt {
259 TT::Punct(_) => 0,
260 TT::Literal(_) => 1,
261 TT::Ident(_) => 2,
262 TT::Group(_) => 3,
263 };
264
265 cmpeq!(discrim(&a), discrim(&b));
266 match (a, b) {
267 (TT::Group(a), TT::Group(b)) => group_cmpeq(a, b, cmp_loc),
268 (a @ TT::Literal(_), b @ TT::Literal(_)) => lit_cmpeq(&a, &b, cmp_loc),
269 (a, b) => Ok(Equality::cmpeq(&a.to_string(), &b.to_string())),
270 }
271}
272
273fn group_cmpeq(
274 a: Group,
275 b: Group,
276 cmp_loc: &ErrorLoc<'_>,
277) -> syn::Result<Equality> {
278 let delim =
279 |g: &Group| Group::new(g.delimiter(), TokenStream::new()).to_string();
280 cmpeq!(delim(&a), delim(&b));
281 ts_cmpeq(a.stream(), b.stream(), cmp_loc)
282}
283
284fn ts_cmpeq(
286 a: TokenStream,
287 b: TokenStream,
288 cmp_loc: &ErrorLoc<'_>,
289) -> syn::Result<Equality> {
290 for ab in a.into_iter().zip_longest(b) {
291 let (a, b) = match ab {
292 EitherOrBoth::Both(a, b) => (a, b),
293 EitherOrBoth::Left(_) => return Ok(Different),
294 EitherOrBoth::Right(_) => return Ok(Different),
295 };
296 match tt_cmpeq(a, b, cmp_loc)? {
297 Equal => {}
298 neq => return Ok(neq),
299 }
300 }
301 return Ok(Equal);
302}
303
304pub fn tokens_cmpeq(
343 a: TokenStream,
344 b: TokenStream,
345 cmp_span: Span,
346) -> syn::Result<Equality> {
347 let a = flatten_none_groups(a);
348 let b = flatten_none_groups(b);
349 ts_cmpeq(a, b, &(cmp_span, "comparison"))
350}