Skip to content

Commit 39ecbda

Browse files
Jithendar12ritiktrianz
authored andcommitted
Replace Confluent ProtobufSchema parser with protoc compiler (awslabs#2796)
1 parent bf953ad commit 39ecbda

File tree

7 files changed

+410
-116
lines changed

7 files changed

+410
-116
lines changed

athena-msk/pom.xml

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -43,28 +43,6 @@
4343
<version>2.1.21</version>
4444
<scope>runtime</scope>
4545
</dependency>
46-
<dependency>
47-
<groupId>com.squareup.wire</groupId>
48-
<artifactId>wire-schema</artifactId>
49-
<version>5.3.3</version>
50-
</dependency>
51-
<dependency>
52-
<groupId>com.squareup.wire</groupId>
53-
<artifactId>wire-schema-jvm</artifactId>
54-
<version>4.9.0</version>
55-
</dependency>
56-
<dependency>
57-
<groupId>com.squareup.wire</groupId>
58-
<artifactId>wire-runtime-jvm</artifactId>
59-
<version>5.3.3</version>
60-
<scope>runtime</scope>
61-
</dependency>
62-
<dependency>
63-
<groupId>com.squareup.wire</groupId>
64-
<artifactId>wire-compiler</artifactId>
65-
<version>5.3.3</version>
66-
<scope>runtime</scope>
67-
</dependency>
6846
<dependency>
6947
<groupId>software.amazon.msk</groupId>
7048
<artifactId>aws-msk-iam-auth</artifactId>
@@ -83,7 +61,7 @@
8361
<dependency>
8462
<groupId>org.apache.kafka</groupId>
8563
<artifactId>kafka-clients</artifactId>
86-
<version>7.9.1-ce</version>
64+
<version>4.0.0</version>
8765
</dependency>
8866
<dependency>
8967
<groupId>org.apache.avro</groupId>
@@ -95,16 +73,17 @@
9573
<artifactId>protobuf-java</artifactId>
9674
<version>3.25.5</version>
9775
</dependency>
76+
<dependency>
77+
<groupId>com.github.os72</groupId>
78+
<artifactId>protoc-jar</artifactId>
79+
<version>3.11.4</version>
80+
</dependency>
81+
9882
<dependency>
9983
<groupId>software.amazon.glue</groupId>
10084
<artifactId>schema-registry-serde</artifactId>
10185
<version>1.1.23</version>
10286
</dependency>
103-
<dependency>
104-
<groupId>io.confluent</groupId>
105-
<artifactId>kafka-protobuf-provider</artifactId>
106-
<version>7.7.2</version>
107-
</dependency>
10887
<dependency>
10988
<groupId>com.fasterxml.jackson.core</groupId>
11089
<artifactId>jackson-annotations</artifactId>
@@ -274,10 +253,4 @@
274253
</plugin>
275254
</plugins>
276255
</build>
277-
<repositories>
278-
<repository>
279-
<id>confluent</id>
280-
<url>https://packages.confluent.io/maven/</url>
281-
</repository>
282-
</repositories>
283256
</project>

athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskMetadataHandler.java

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,11 @@
3838
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
3939
import com.amazonaws.athena.connector.util.PaginatedRequestIterator;
4040
import com.amazonaws.athena.connectors.msk.dto.AvroTopicSchema;
41+
import com.amazonaws.athena.connectors.msk.dto.MSKField;
4142
import com.amazonaws.athena.connectors.msk.dto.SplitParameters;
4243
import com.amazonaws.athena.connectors.msk.dto.TopicPartitionPiece;
4344
import com.amazonaws.athena.connectors.msk.dto.TopicSchema;
4445
import com.google.common.annotations.VisibleForTesting;
45-
import com.google.protobuf.Descriptors;
46-
import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema;
4746
import org.apache.arrow.vector.types.pojo.Field;
4847
import org.apache.arrow.vector.types.pojo.FieldType;
4948
import org.apache.arrow.vector.types.pojo.Schema;
@@ -460,19 +459,21 @@ private Schema getSchema(String glueRegistryName, String glueSchemaName) throws
460459
schemaBuilder.addMetadata("dataFormat", AVRO_DATA_FORMAT);
461460
}
462461
else if (dataFormat.equalsIgnoreCase(PROTOBUF_DATA_FORMAT)) {
463-
// Get protobuf topic schema from Glue registry
464-
String glueSchema = registryReader.getSchemaDef(glueRegistryName, glueSchemaName);
465-
ProtobufSchema protobufSchema = new ProtobufSchema(glueSchema);
466-
Descriptors.Descriptor descriptor = protobufSchema.toDescriptor();
467-
// Creating ArrowType for each field in the protobuf topic schema.
468-
for (Descriptors.FieldDescriptor fieldDescriptor : descriptor.getFields()) {
462+
// Get protobuf topic schema from Glue registry using protoc compiler
463+
List<MSKField> fields = registryReader.getProtobufFields(glueRegistryName, glueSchemaName);
464+
// Creating ArrowType for each field in the protobuf topic schema
465+
for (MSKField field : fields) {
469466
FieldType fieldType = new FieldType(
470467
true,
471-
AmazonMskUtils.toArrowType(fieldDescriptor.getType().toString()),
472-
null
468+
AmazonMskUtils.toArrowType(field.getType()),
469+
null,
470+
com.google.common.collect.ImmutableMap.of(
471+
"name", field.getName(),
472+
"type", field.getType()
473+
)
473474
);
474-
Field field = new Field(fieldDescriptor.getName(), fieldType, null);
475-
schemaBuilder.addField(field);
475+
Field arrowField = new Field(field.getName(), fieldType, null);
476+
schemaBuilder.addField(arrowField);
476477
}
477478
schemaBuilder.addMetadata("dataFormat", PROTOBUF_DATA_FORMAT);
478479
}

athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/AmazonMskUtils.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,6 @@ enum AuthType {
100100
private static final String KAFKA_MAX_PARTITION_FETCH_BYTES_CONFIG = "max.partition.fetch.bytes";
101101
private static final String KAFKA_KEY_DESERIALIZER_CLASS_CONFIG = "key.deserializer";
102102
private static final String KAFKA_VALUE_DESERIALIZER_CLASS_CONFIG = "value.deserializer";
103-
104103
private static final ObjectMapper objectMapper = new ObjectMapper();
105104

106105
private AmazonMskUtils() {}

athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/GlueRegistryReader.java

Lines changed: 115 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,26 +19,139 @@
1919
*/
2020
package com.amazonaws.athena.connectors.msk;
2121

22+
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
23+
import com.amazonaws.athena.connectors.msk.dto.MSKField;
2224
import com.fasterxml.jackson.databind.DeserializationFeature;
2325
import com.fasterxml.jackson.databind.ObjectMapper;
26+
import com.github.os72.protocjar.Protoc;
27+
import com.google.protobuf.DescriptorProtos.DescriptorProto;
28+
import com.google.protobuf.DescriptorProtos.FieldDescriptorProto;
29+
import com.google.protobuf.DescriptorProtos.FileDescriptorSet;
30+
import org.slf4j.Logger;
31+
import org.slf4j.LoggerFactory;
2432
import software.amazon.awssdk.services.glue.GlueClient;
33+
import software.amazon.awssdk.services.glue.model.ErrorDetails;
34+
import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;
2535
import software.amazon.awssdk.services.glue.model.GetSchemaRequest;
2636
import software.amazon.awssdk.services.glue.model.GetSchemaResponse;
2737
import software.amazon.awssdk.services.glue.model.GetSchemaVersionRequest;
2838
import software.amazon.awssdk.services.glue.model.GetSchemaVersionResponse;
2939
import software.amazon.awssdk.services.glue.model.SchemaId;
3040
import software.amazon.awssdk.services.glue.model.SchemaVersionNumber;
3141

42+
import java.io.FileInputStream;
43+
import java.io.IOException;
44+
import java.nio.file.Files;
45+
import java.nio.file.Path;
46+
import java.nio.file.Paths;
47+
import java.util.ArrayList;
48+
import java.util.List;
49+
import java.util.UUID;
50+
3251
public class GlueRegistryReader
3352
{
53+
private static final Logger logger = LoggerFactory.getLogger(GlueRegistryReader.class);
3454
private static final ObjectMapper objectMapper;
55+
private static final String PROTO_FILE = "schema.proto";
56+
private static final String DESC_FILE = "schema.desc";
3557

3658
static {
3759
objectMapper = new ObjectMapper();
3860
objectMapper.enable(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT);
3961
objectMapper.disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES);
4062
}
4163

64+
/**
65+
* Parse protobuf schema definition from Glue Schema Registry using protoc compiler
66+
* @param glueRegistryName Registry name
67+
* @param glueSchemaName Schema name
68+
* @return List of MSKField objects containing field information
69+
* @throws AthenaConnectorException if schema parsing fails
70+
*/
71+
public List<MSKField> getProtobufFields(String glueRegistryName, String glueSchemaName)
72+
{
73+
// Get schema from Glue
74+
GetSchemaVersionResponse schemaVersionResponse = getSchemaVersionResult(glueRegistryName, glueSchemaName);
75+
String schemaDef = schemaVersionResponse.schemaDefinition();
76+
77+
// Create a unique temp directory using UUID
78+
Path protoDir = Paths.get("/tmp", "proto_" + UUID.randomUUID());
79+
Path protoFile = protoDir.resolve(PROTO_FILE);
80+
Path descFile = protoDir.resolve(DESC_FILE);
81+
82+
try {
83+
Files.createDirectories(protoDir);
84+
Files.writeString(protoFile, schemaDef);
85+
// Compile using protoc-jar
86+
int exitCode = Protoc.runProtoc(new String[]{
87+
"--descriptor_set_out=" + descFile.toAbsolutePath(),
88+
"--proto_path=" + protoDir.toAbsolutePath(),
89+
protoFile.getFileName().toString()
90+
});
91+
92+
if (exitCode != 0 || !Files.exists(descFile)) {
93+
throw new AthenaConnectorException(
94+
"Failed to generate descriptor set with protoc",
95+
ErrorDetails.builder()
96+
.errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString())
97+
.build()
98+
);
99+
}
100+
101+
try (FileInputStream fis = new FileInputStream(descFile.toFile())) {
102+
FileDescriptorSet descriptorSet = FileDescriptorSet.parseFrom(fis);
103+
104+
if (descriptorSet.getFileList().isEmpty() ||
105+
descriptorSet.getFile(0).getMessageTypeList().isEmpty()) {
106+
throw new AthenaConnectorException(
107+
"No message types found in compiled schema",
108+
ErrorDetails.builder()
109+
.errorCode(FederationSourceErrorCode.INVALID_RESPONSE_EXCEPTION.toString())
110+
.build()
111+
);
112+
}
113+
114+
List<MSKField> fields = new ArrayList<>();
115+
DescriptorProto messageType = descriptorSet.getFile(0).getMessageType(0);
116+
for (FieldDescriptorProto field : messageType.getFieldList()) {
117+
String fieldType = getFieldTypeString(field);
118+
fields.add(new MSKField(field.getName(), fieldType));
119+
}
120+
121+
return fields;
122+
}
123+
}
124+
catch (IOException | InterruptedException e) {
125+
throw new AthenaConnectorException(
126+
"Error while handling schema files or protoc execution",
127+
ErrorDetails.builder()
128+
.errorCode(FederationSourceErrorCode.INTERNAL_SERVICE_EXCEPTION.toString())
129+
.build()
130+
);
131+
}
132+
finally {
133+
// Clean up temporary files
134+
try {
135+
Files.deleteIfExists(protoFile);
136+
Files.deleteIfExists(descFile);
137+
Files.deleteIfExists(protoDir);
138+
}
139+
catch (IOException e) {
140+
logger.warn("Failed to clean up temporary proto directory: {}", protoDir.toAbsolutePath(), e);
141+
}
142+
}
143+
}
144+
145+
/**
146+
* Convert protobuf field type to string representation
147+
*/
148+
private String getFieldTypeString(FieldDescriptorProto field)
149+
{
150+
String baseType = field.getType().toString().toLowerCase().replace("type_", "");
151+
return field.getLabel() == FieldDescriptorProto.Label.LABEL_REPEATED ?
152+
"repeated " + baseType : baseType;
153+
}
154+
42155
/**
43156
* Fetch glue schema content for latest version
44157
* @param glueRegistryName
@@ -62,6 +175,7 @@ public GetSchemaVersionResponse getSchemaVersionResult(String glueRegistryName,
62175
.build()
63176
);
64177
}
178+
65179
/**
66180
* fetch schema file content from glue schema.
67181
*
@@ -77,14 +191,10 @@ public <T> T getGlueSchema(String glueRegistryName, String glueSchemaName, Class
77191
GetSchemaVersionResponse result = getSchemaVersionResult(glueRegistryName, glueSchemaName);
78192
return objectMapper.readValue(result.schemaDefinition(), clazz);
79193
}
194+
80195
public String getGlueSchemaType(String glueRegistryName, String glueSchemaName)
81196
{
82197
GetSchemaVersionResponse result = getSchemaVersionResult(glueRegistryName, glueSchemaName);
83198
return result.dataFormatAsString();
84199
}
85-
public String getSchemaDef(String glueRegistryName, String glueSchemaName)
86-
{
87-
GetSchemaVersionResponse result = getSchemaVersionResult(glueRegistryName, glueSchemaName);
88-
return result.schemaDefinition();
89-
}
90200
}

athena-msk/src/main/java/com/amazonaws/athena/connectors/msk/dto/MSKField.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ public MSKField(String name, String mapping, String type, String formatHint, Obj
4040
this.value = value;
4141
}
4242

43+
public MSKField(String name, String type)
44+
{
45+
this.name = name;
46+
this.type = type;
47+
}
48+
4349
public String getName()
4450
{
4551
return name;

athena-msk/src/test/java/com/amazonaws/athena/connectors/msk/AmazonMskMetadataHandlerTest.java

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import com.amazonaws.athena.connector.lambda.domain.predicate.Constraints;
2828
import com.amazonaws.athena.connector.lambda.metadata.*;
2929
import com.amazonaws.athena.connector.lambda.security.FederatedIdentity;
30+
import org.apache.arrow.vector.types.Types;
3031
import org.apache.kafka.clients.consumer.MockConsumer;
3132
import org.apache.kafka.clients.consumer.OffsetResetStrategy;
3233
import org.apache.kafka.common.PartitionInfo;
@@ -57,8 +58,10 @@
5758
import java.util.Map;
5859
import java.util.stream.Collectors;
5960

61+
import static com.amazonaws.athena.connectors.msk.AmazonMskConstants.PROTOBUF_DATA_FORMAT;
6062
import static org.junit.Assert.assertEquals;
6163
import static org.junit.Assert.assertNull;
64+
import static org.junit.Assert.fail;
6265
import static org.mockito.ArgumentMatchers.any;
6366
import static org.mockito.Mockito.mock;
6467
import static org.mockito.Mockito.when;
@@ -181,6 +184,65 @@ public void testDoGetTable() throws Exception {
181184
assertEquals(1, getTableResponse.getSchema().getFields().size());
182185
}
183186

187+
@Test
188+
public void testDoGetTableWithProtobufSchema()
189+
{
190+
String arn = "defaultarn";
191+
String schemaName = "defaultschemaname";
192+
String schemaVersionId = "defaultversionid";
193+
Long latestSchemaVersion = 123L;
194+
GetSchemaResponse getSchemaResponse = GetSchemaResponse.builder()
195+
.schemaArn(arn)
196+
.schemaName(schemaName)
197+
.latestSchemaVersion(latestSchemaVersion)
198+
.build();
199+
GetSchemaVersionResponse getSchemaVersionResponse = GetSchemaVersionResponse.builder()
200+
.schemaArn(arn)
201+
.schemaVersionId(schemaVersionId)
202+
.schemaDefinition("syntax = \"proto3\";\n" +
203+
"package test;\n" +
204+
"message TestMessage {\n" +
205+
" int32 id = 1;\n" +
206+
" string name = 2;\n" +
207+
" double value = 3;\n" +
208+
"}")
209+
.dataFormat(DataFormat.PROTOBUF)
210+
.build();
211+
Mockito.when(awsGlue.getSchema(any(GetSchemaRequest.class))).thenReturn(getSchemaResponse);
212+
Mockito.when(awsGlue.getSchemaVersion(any(GetSchemaVersionRequest.class))).thenReturn(getSchemaVersionResponse);
213+
214+
GetTableRequest getTableRequest = new GetTableRequest(
215+
federatedIdentity,
216+
QUERY_ID,
217+
"kafka",
218+
new TableName("default", "testmessage"),
219+
Collections.emptyMap()
220+
);
221+
222+
GetTableResponse getTableResponse;
223+
try {
224+
getTableResponse = amazonMskMetadataHandler.doGetTable(blockAllocator, getTableRequest);
225+
}
226+
catch (Exception e) {
227+
fail("Unexpected exception in doGetTable():" + e.getMessage());
228+
return;
229+
}
230+
231+
// Verify schema field names
232+
assertEquals(3, getTableResponse.getSchema().getFields().size());
233+
assertEquals("id", getTableResponse.getSchema().getFields().get(0).getName());
234+
assertEquals("name", getTableResponse.getSchema().getFields().get(1).getName());
235+
assertEquals("value", getTableResponse.getSchema().getFields().get(2).getName());
236+
237+
// verify schema field types
238+
assertEquals(Types.MinorType.INT.getType(), getTableResponse.getSchema().getFields().get(0).getType());
239+
assertEquals(Types.MinorType.VARCHAR.getType(), getTableResponse.getSchema().getFields().get(1).getType());
240+
assertEquals(Types.MinorType.FLOAT8.getType(), getTableResponse.getSchema().getFields().get(2).getType());
241+
242+
// Verify data format metadata
243+
assertEquals(PROTOBUF_DATA_FORMAT, getTableResponse.getSchema().getCustomMetadata().get("dataFormat"));
244+
}
245+
184246
@Test
185247
public void testDoGetSplits() throws Exception
186248
{

0 commit comments

Comments
 (0)