@@ -202,23 +202,53 @@ impl<'fmt, 'ast, 'buf> JoinNodesBuilder<'fmt, 'ast, 'buf> {
202
202
}
203
203
}
204
204
205
+ #[ derive( Copy , Clone , Debug ) ]
206
+ enum Entries {
207
+ /// No previous entry
208
+ None ,
209
+ /// One previous ending at the given position.
210
+ One ( TextSize ) ,
211
+ /// More than one entry, the last one ending at the specific position.
212
+ MoreThanOne ( TextSize ) ,
213
+ }
214
+
215
+ impl Entries {
216
+ fn position ( self ) -> Option < TextSize > {
217
+ match self {
218
+ Entries :: None => None ,
219
+ Entries :: One ( position) | Entries :: MoreThanOne ( position) => Some ( position) ,
220
+ }
221
+ }
222
+
223
+ const fn is_one_or_more ( self ) -> bool {
224
+ !matches ! ( self , Entries :: None )
225
+ }
226
+
227
+ const fn is_more_than_one ( self ) -> bool {
228
+ matches ! ( self , Entries :: MoreThanOne ( _) )
229
+ }
230
+
231
+ const fn next ( self , end_position : TextSize ) -> Self {
232
+ match self {
233
+ Entries :: None => Entries :: One ( end_position) ,
234
+ Entries :: One ( _) | Entries :: MoreThanOne ( _) => Entries :: MoreThanOne ( end_position) ,
235
+ }
236
+ }
237
+ }
238
+
205
239
pub ( crate ) struct JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > {
206
240
result : FormatResult < ( ) > ,
207
241
fmt : & ' fmt mut PyFormatter < ' ast , ' buf > ,
208
- end_of_last_entry : Option < TextSize > ,
242
+ entries : Entries ,
209
243
sequence_end : TextSize ,
210
- /// We need to track whether we have more than one entry since a sole entry doesn't get a
211
- /// magic trailing comma even when expanded
212
- len : usize ,
213
244
}
214
245
215
246
impl < ' fmt , ' ast , ' buf > JoinCommaSeparatedBuilder < ' fmt , ' ast , ' buf > {
216
247
fn new ( f : & ' fmt mut PyFormatter < ' ast , ' buf > , sequence_end : TextSize ) -> Self {
217
248
Self {
218
249
fmt : f,
219
250
result : Ok ( ( ) ) ,
220
- end_of_last_entry : None ,
221
- len : 0 ,
251
+ entries : Entries :: None ,
222
252
sequence_end,
223
253
}
224
254
}
@@ -245,12 +275,11 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
245
275
Separator : Format < PyFormatContext < ' ast > > ,
246
276
{
247
277
self . result = self . result . and_then ( |_| {
248
- if self . end_of_last_entry . is_some ( ) {
278
+ if self . entries . is_one_or_more ( ) {
249
279
write ! ( self . fmt, [ text( "," ) , separator] ) ?;
250
280
}
251
281
252
- self . end_of_last_entry = Some ( node. end ( ) ) ;
253
- self . len += 1 ;
282
+ self . entries = self . entries . next ( node. end ( ) ) ;
254
283
255
284
content. fmt ( self . fmt )
256
285
} ) ;
@@ -286,7 +315,7 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
286
315
287
316
pub ( crate ) fn finish ( & mut self ) -> FormatResult < ( ) > {
288
317
self . result . and_then ( |_| {
289
- if let Some ( last_end) = self . end_of_last_entry . take ( ) {
318
+ if let Some ( last_end) = self . entries . position ( ) {
290
319
let magic_trailing_comma = match self . fmt . options ( ) . magic_trailing_comma ( ) {
291
320
MagicTrailingComma :: Respect => {
292
321
let first_token = SimpleTokenizer :: new (
@@ -310,7 +339,7 @@ impl<'fmt, 'ast, 'buf> JoinCommaSeparatedBuilder<'fmt, 'ast, 'buf> {
310
339
311
340
// If there is a single entry, only keep the magic trailing comma, don't add it if
312
341
// it wasn't there. If there is more than one entry, always add it.
313
- if magic_trailing_comma || self . len > 1 {
342
+ if magic_trailing_comma || self . entries . is_more_than_one ( ) {
314
343
if_group_breaks ( & text ( "," ) ) . fmt ( self . fmt ) ?;
315
344
}
316
345
0 commit comments