Skip to content

Commit b5f238b

Browse files
dirkbonhommeartembilan
authored andcommitted
GH-111: Implement batch mode for Kcl adapter
Fixes #111
1 parent c601324 commit b5f238b

File tree

3 files changed

+185
-85
lines changed

3 files changed

+185
-85
lines changed

src/main/java/org/springframework/integration/aws/inbound/kinesis/KclMessageDrivenChannelAdapter.java

+118-39
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@
1616

1717
package org.springframework.integration.aws.inbound.kinesis;
1818

19+
import java.util.ArrayList;
1920
import java.util.Arrays;
2021
import java.util.List;
2122
import java.util.UUID;
23+
import java.util.stream.Collectors;
24+
25+
import javax.annotation.Nullable;
2226

2327
import org.springframework.core.AttributeAccessor;
28+
import org.springframework.core.convert.converter.Converter;
29+
import org.springframework.core.serializer.support.DeserializingConverter;
2430
import org.springframework.core.task.SimpleAsyncTaskExecutor;
2531
import org.springframework.core.task.TaskExecutor;
2632
import org.springframework.core.task.support.ExecutorServiceAdapter;
@@ -101,6 +107,10 @@ public class KclMessageDrivenChannelAdapter extends MessageProducerSupport {
101107

102108
private int consumerBackoff;
103109

110+
private Converter<byte[], Object> converter = new DeserializingConverter();
111+
112+
private ListenerMode listenerMode = ListenerMode.record;
113+
104114
private long checkpointsInterval = 5_000L;
105115

106116
private CheckpointMode checkpointMode = CheckpointMode.batch;
@@ -167,6 +177,20 @@ public void setConsumerBackoff(int consumerBackoff) {
167177
this.consumerBackoff = Math.max(1000, consumerBackoff);
168178
}
169179

180+
/**
181+
* Specify a {@link Converter} to deserialize the {@code byte[]} from record's body.
182+
* Can be {@code null} meaning no deserialization.
183+
* @param converter the {@link Converter} to use or null
184+
*/
185+
public void setConverter(Converter<byte[], Object> converter) {
186+
this.converter = converter;
187+
}
188+
189+
public void setListenerMode(ListenerMode listenerMode) {
190+
Assert.notNull(listenerMode, "'listenerMode' must not be null");
191+
this.listenerMode = listenerMode;
192+
}
193+
170194
/**
171195
* Sets the interval between 2 checkpoints.
172196
* @param checkpointsInterval interval between 2 checkpoints (in milliseconds)
@@ -226,6 +250,13 @@ KinesisClientLibConfiguration.DEFAULT_CLEANUP_LEASES_UPON_SHARDS_COMPLETION, new
226250
@Override
227251
protected void doStart() {
228252
super.doStart();
253+
254+
if (ListenerMode.batch.equals(this.listenerMode) && CheckpointMode.record.equals(this.checkpointMode)) {
255+
this.checkpointMode = CheckpointMode.batch;
256+
logger.warn("The 'checkpointMode' is overridden from [CheckpointMode.record] to [CheckpointMode.batch] "
257+
+ "because it does not make sense in case of [ListenerMode.batch].");
258+
}
259+
229260
this.executor.execute(this.scheduler);
230261
}
231262

@@ -288,47 +319,63 @@ public void processRecords(List<Record> records, IRecordProcessorCheckpointer ch
288319
if (logger.isDebugEnabled()) {
289320
logger.debug("Processing " + records.size() + " records from " + this.shardId);
290321
}
291-
for (Record record : records) {
292-
try {
293-
processSingleRecord(record, checkpointer);
294-
}
295-
catch (Throwable t) {
296-
logger.warn("Caught throwable while processing record " + record, t);
297-
}
298-
finally {
299-
attributesHolder.remove();
300-
// Checkpoint once every checkpoint interval.
301-
if (CheckpointMode.periodic.equals(KclMessageDrivenChannelAdapter.this.checkpointMode) &&
302-
System.currentTimeMillis() > nextCheckpointTimeInMillis) {
303-
checkpoint(checkpointer);
304-
this.nextCheckpointTimeInMillis = System.currentTimeMillis() + checkpointsInterval;
322+
323+
try {
324+
if (ListenerMode.record.equals(KclMessageDrivenChannelAdapter.this.listenerMode)) {
325+
for (Record record : records) {
326+
processSingleRecord(record, checkpointer);
327+
checkpointIfRecordMode(checkpointer, record);
328+
checkpointIfPeriodicMode(checkpointer, record);
305329
}
306330
}
331+
else if (ListenerMode.batch.equals(KclMessageDrivenChannelAdapter.this.listenerMode)) {
332+
processMultipleRecords(records, checkpointer);
333+
checkpointIfPeriodicMode(checkpointer, null);
334+
}
335+
checkpointIfBatchMode(checkpointer);
307336
}
308-
309-
// checkpoint if needed
310-
if (CheckpointMode.batch.equals(KclMessageDrivenChannelAdapter.this.checkpointMode)) {
311-
checkpoint(checkpointer);
337+
finally {
338+
attributesHolder.remove();
312339
}
313340
}
314341

315-
/**
316-
* Process a single record.
317-
* @param record The record to be processed.
318-
* @param checkpointer the checkpointer to use if the checkpointMode is record
319-
*/
320342
private void processSingleRecord(Record record, IRecordProcessorCheckpointer checkpointer) {
321-
// Convert AWS Record in Spring Message.
322-
performSend(prepareMessageForRecord(record, checkpointer), record);
343+
performSend(prepareMessageForRecord(record), record, checkpointer);
344+
}
323345

324-
// checkpoint if needed
325-
if (CheckpointMode.record.equals(KclMessageDrivenChannelAdapter.this.checkpointMode)) {
326-
checkpoint(checkpointer);
346+
private void processMultipleRecords(List<Record> records, IRecordProcessorCheckpointer checkpointer) {
347+
Object payload = records;
348+
349+
if (KclMessageDrivenChannelAdapter.this.embeddedHeadersMapper != null) {
350+
payload = records.stream().map(this::prepareMessageForRecord).collect(Collectors.toList());
351+
}
352+
353+
final List<String> partitionKeys;
354+
final List<String> sequenceNumbers;
355+
if (KclMessageDrivenChannelAdapter.this.converter != null) {
356+
partitionKeys = new ArrayList<>();
357+
sequenceNumbers = new ArrayList<>();
358+
359+
payload = records.stream().map(r -> {
360+
partitionKeys.add(r.getPartitionKey());
361+
sequenceNumbers.add(r.getSequenceNumber());
362+
363+
return KclMessageDrivenChannelAdapter.this.converter.convert(r.getData().array());
364+
}).collect(Collectors.toList());
365+
}
366+
else {
367+
partitionKeys = null;
368+
sequenceNumbers = null;
327369
}
370+
371+
AbstractIntegrationMessageBuilder<?> messageBuilder = getMessageBuilderFactory().withPayload(payload)
372+
.setHeader(AwsHeaders.RECEIVED_PARTITION_KEY, partitionKeys)
373+
.setHeader(AwsHeaders.RECEIVED_SEQUENCE_NUMBER, sequenceNumbers);
374+
375+
performSend(messageBuilder, records, checkpointer);
328376
}
329377

330-
private AbstractIntegrationMessageBuilder<Object> prepareMessageForRecord(Record record,
331-
IRecordProcessorCheckpointer checkpointer) {
378+
private AbstractIntegrationMessageBuilder<Object> prepareMessageForRecord(Record record) {
332379
Object payload = record.getData().array();
333380
Message<?> messageToUse = null;
334381

@@ -347,11 +394,13 @@ private AbstractIntegrationMessageBuilder<Object> prepareMessageForRecord(Record
347394
}
348395
}
349396

397+
if (payload instanceof byte[] && KclMessageDrivenChannelAdapter.this.converter != null) {
398+
payload = KclMessageDrivenChannelAdapter.this.converter.convert((byte[]) payload);
399+
}
400+
350401
AbstractIntegrationMessageBuilder<Object> messageBuilder = getMessageBuilderFactory().withPayload(payload)
351402
.setHeader(AwsHeaders.RECEIVED_PARTITION_KEY, record.getPartitionKey())
352-
.setHeader(AwsHeaders.RECEIVED_SEQUENCE_NUMBER, record.getSequenceNumber())
353-
.setHeader(AwsHeaders.RECEIVED_STREAM, KclMessageDrivenChannelAdapter.this.stream)
354-
.setHeader(AwsHeaders.SHARD, this.shardId);
403+
.setHeader(AwsHeaders.RECEIVED_SEQUENCE_NUMBER, record.getSequenceNumber());
355404

356405
if (KclMessageDrivenChannelAdapter.this.bindSourceRecord) {
357406
messageBuilder.setHeader(IntegrationMessageHeaderAccessor.SOURCE_DATA, record);
@@ -361,14 +410,18 @@ private AbstractIntegrationMessageBuilder<Object> prepareMessageForRecord(Record
361410
messageBuilder.copyHeadersIfAbsent(messageToUse.getHeaders());
362411
}
363412

413+
return messageBuilder;
414+
}
415+
416+
private void performSend(AbstractIntegrationMessageBuilder<?> messageBuilder, Object rawRecord,
417+
IRecordProcessorCheckpointer checkpointer) {
418+
messageBuilder.setHeader(AwsHeaders.RECEIVED_STREAM, KclMessageDrivenChannelAdapter.this.stream)
419+
.setHeader(AwsHeaders.SHARD, this.shardId);
420+
364421
if (CheckpointMode.manual.equals(KclMessageDrivenChannelAdapter.this.checkpointMode)) {
365422
messageBuilder.setHeader(AwsHeaders.CHECKPOINTER, checkpointer);
366423
}
367424

368-
return messageBuilder;
369-
}
370-
371-
private void performSend(AbstractIntegrationMessageBuilder<?> messageBuilder, Object rawRecord) {
372425
Message<?> messageToSend = messageBuilder.build();
373426
setAttributesIfNecessary(rawRecord, messageToSend);
374427
try {
@@ -397,13 +450,19 @@ private void setAttributesIfNecessary(Object record, Message<?> message) {
397450
/**
398451
* Checkpoint with retries.
399452
* @param checkpointer checkpointer
453+
* @param record last processed record
400454
*/
401-
private void checkpoint(IRecordProcessorCheckpointer checkpointer) {
455+
private void checkpoint(IRecordProcessorCheckpointer checkpointer, @Nullable Record record) {
402456
if (logger.isInfoEnabled()) {
403457
logger.info("Checkpointing shard " + shardId);
404458
}
405459
try {
406-
checkpointer.checkpoint();
460+
if (record == null) {
461+
checkpointer.checkpoint();
462+
}
463+
else {
464+
checkpointer.checkpoint(record);
465+
}
407466
}
408467
catch (ShutdownException se) {
409468
// Ignore checkpoint if the processor instance has been shutdown (fail
@@ -424,6 +483,26 @@ private void checkpoint(IRecordProcessorCheckpointer checkpointer) {
424483
}
425484
}
426485

486+
private void checkpointIfBatchMode(IRecordProcessorCheckpointer checkpointer) {
487+
if (CheckpointMode.batch.equals(KclMessageDrivenChannelAdapter.this.checkpointMode)) {
488+
checkpoint(checkpointer, null);
489+
}
490+
}
491+
492+
private void checkpointIfRecordMode(IRecordProcessorCheckpointer checkpointer, Record record) {
493+
if (CheckpointMode.record.equals(KclMessageDrivenChannelAdapter.this.checkpointMode)) {
494+
checkpoint(checkpointer, record);
495+
}
496+
}
497+
498+
private void checkpointIfPeriodicMode(IRecordProcessorCheckpointer checkpointer, @Nullable Record record) {
499+
if (CheckpointMode.periodic.equals(KclMessageDrivenChannelAdapter.this.checkpointMode)
500+
&& System.currentTimeMillis() > nextCheckpointTimeInMillis) {
501+
checkpoint(checkpointer, record);
502+
this.nextCheckpointTimeInMillis = System.currentTimeMillis() + checkpointsInterval;
503+
}
504+
}
505+
427506
@Override
428507
public void shutdown(IRecordProcessorCheckpointer checkpointer, ShutdownReason reason) {
429508
if (logger.isInfoEnabled()) {

0 commit comments

Comments
 (0)