Skip to content
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/*-
* #%L
* athena-vertica
* %%
* Copyright (C) 2019 - 2025 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.vertica;

import org.junit.Before;
import org.junit.Test;

import java.util.HashMap;
import java.util.Map;

import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DATABASE;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.DEFAULT;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.HOST;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.PORT;
import static com.amazonaws.athena.connector.lambda.connection.EnvironmentConstants.SECRET_NAME;
import static org.junit.Assert.assertEquals;

public class VerticaEnvironmentPropertiesTest {
private Map<String, String> connectionProperties;
private VerticaEnvironmentProperties verticaEnvironmentProperties;

@Before
public void setUp() {
connectionProperties = new HashMap<>();
connectionProperties.put(HOST, "vertica-cluster-endpoint");
connectionProperties.put(DATABASE, "verticadb");
connectionProperties.put(SECRET_NAME, "vertica-secret");
connectionProperties.put(PORT, "1234");
verticaEnvironmentProperties = new VerticaEnvironmentProperties();
}

@Test
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's also add negative test cases here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank You, added test cases for negative scenarios.

public void verticaConnectionPropertiesTest() {
Map<String, String> verticaConnectionProperties = verticaEnvironmentProperties.connectionPropertiesToEnvironment(connectionProperties);

String expectedConnectionString = "vertica://jdbc:vertica://vertica-cluster-endpoint:1234/verticadb?${vertica-secret}";
assertEquals(expectedConnectionString, verticaConnectionProperties.get(DEFAULT));
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import java.io.ByteArrayInputStream;
import java.io.InputStream;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
Expand All @@ -31,6 +32,12 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.BitVector;
import org.apache.arrow.vector.TinyIntVector;
import org.apache.arrow.vector.SmallIntVector;
import org.apache.arrow.vector.Float8Vector;
import org.apache.arrow.vector.Float4Vector;
import org.apache.arrow.vector.DecimalVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
Expand Down Expand Up @@ -164,6 +171,7 @@ public void after()
logger.info("{}: exit ", testName.getMethodName());
}


@Test
public void doReadRecordsNoSpill()
throws Exception
Expand Down Expand Up @@ -305,45 +313,97 @@ public byte[] getBytes()
private VectorSchemaRoot createRoot()
{
Schema schema = SchemaBuilder.newBuilder()
.addBitField("bitField")
.addTinyIntField("tinyIntField")
.addSmallIntField("smallIntField")
.addBigIntField("day")
.addBigIntField("month")
.addBigIntField("year")
.addFloat4Field("float4Field")
.addFloat8Field("float8Field")
.addDecimalField("decimalField", 38, 10)

.addStringField("preparedStmt")
.addStringField("queryId")
.addStringField("awsRegionSql")
.build();

VectorSchemaRoot schemaRoot = VectorSchemaRoot.create(schema, bufferAllocator);

BitVector bitVector = (BitVector) schemaRoot.getVector("bitField");
bitVector.allocateNew(2);
bitVector.set(0, 1);
bitVector.set(1, 0);
bitVector.setValueCount(2);

TinyIntVector tinyIntVector = (TinyIntVector) schemaRoot.getVector("tinyIntField");
tinyIntVector.allocateNew(2);
tinyIntVector.set(0, (byte) 10);
tinyIntVector.set(1, (byte) 20);
tinyIntVector.setValueCount(2);

SmallIntVector smallIntVector = (SmallIntVector) schemaRoot.getVector("smallIntField");
smallIntVector.allocateNew(2);
smallIntVector.set(0, (short) 100);
smallIntVector.set(1, (short) 200);
smallIntVector.setValueCount(2);

BigIntVector dayVector = (BigIntVector) schemaRoot.getVector("day");
dayVector.allocateNew(2);
dayVector.set(0, 0);
dayVector.set(1, 1);
dayVector.setValueCount(2);

BigIntVector monthVector = (BigIntVector) schemaRoot.getVector("month");
monthVector.allocateNew(2);
monthVector.set(0, 0);
monthVector.set(1, 1);
monthVector.setValueCount(2);

BigIntVector yearVector = (BigIntVector) schemaRoot.getVector("year");
yearVector.allocateNew(2);
yearVector.set(0, 2000);
yearVector.set(1, 2001);
yearVector.setValueCount(2);

Float4Vector float4Vector = (Float4Vector) schemaRoot.getVector("float4Field");
float4Vector.allocateNew(2);
float4Vector.set(0, 1.5f);
float4Vector.set(1, 2.5f);
float4Vector.setValueCount(2);

Float8Vector float8Vector = (Float8Vector) schemaRoot.getVector("float8Field");
float8Vector.allocateNew(2);
float8Vector.set(0, 3.141592653);
float8Vector.set(1, 2.718281828);
float8Vector.setValueCount(2);

DecimalVector decimalVector = (DecimalVector) schemaRoot.getVector("decimalField");
decimalVector.allocateNew(2);
decimalVector.set(0, new BigDecimal("123.4567890123"));
decimalVector.set(1, new BigDecimal("987.6543210987"));
decimalVector.setValueCount(2);

VarCharVector stmtVector = (VarCharVector) schemaRoot.getVector("preparedStmt");
stmtVector.allocateNew(2);
stmtVector.set(0, new Text("test1"));
stmtVector.set(1, new Text("test2"));
stmtVector.setValueCount(2);

VarCharVector idVector = (VarCharVector) schemaRoot.getVector("queryId");
idVector.allocateNew(2);
idVector.set(0, new Text("queryID1"));
idVector.set(1, new Text("queryID2"));
idVector.setValueCount(2);

VarCharVector regionVector = (VarCharVector) schemaRoot.getVector("awsRegionSql");
regionVector.allocateNew(2);
regionVector.set(0, new Text("region1"));
regionVector.set(1, new Text("region2"));
regionVector.setValueCount(2);

schemaRoot.setRowCount(2);
return schemaRoot;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,39 @@
package com.amazonaws.athena.connectors.vertica;

import com.amazonaws.athena.connector.lambda.domain.TableName;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.sql.*;
import org.mockito.junit.MockitoJUnitRunner;

import java.lang.reflect.Method;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.sql.Types;
import java.sql.ResultSet;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.SQLException;
import java.util.Base64;
import java.util.concurrent.atomic.AtomicInteger;

import static org.mockito.ArgumentMatchers.nullable;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.when;

@RunWith(MockitoJUnitRunner.class)
public class VerticaSchemaUtilsTest extends TestBase
{
private static final Logger logger = LoggerFactory.getLogger(VerticaSchemaUtils.class);
private Connection connection;
private DatabaseMetaData databaseMetaData;
private TableName tableName;

@Mock
private X509Certificate mockCertificate;

@Before
public void setUp() throws SQLException
Expand All @@ -60,11 +72,27 @@ public void buildTableSchema() throws SQLException
int [] types = {Types.INTEGER, Types.INTEGER};*/

String[] schema = {"TABLE_SCHEM", "TABLE_NAME", "COLUMN_NAME", "TYPE_NAME"};
Object[][] values = {{"testSchema", "testTable1", "id", "bigint"}, {"testSchema", "testTable1", "date", "timestamp"},
{"testSchema", "testTable1", "orders", "integer"}, {"testSchema", "testTable1", "price", "float4"},
{"testSchema", "testTable1", "shop", "varchar"}
};
int[] types = {Types.BIGINT, Types.TIMESTAMP, Types.INTEGER,Types.FLOAT, Types.VARCHAR, Types.VARCHAR};
Object[][] values = {
{"testSchema", "testTable1", "bit_col", "BIT"},
{"testSchema", "testTable1", "tinyint_col", "TINYINT"},
{"testSchema", "testTable1", "smallint_col", "SMALLINT"},
{"testSchema", "testTable1", "integer_col", "INTEGER"},
{"testSchema", "testTable1", "bigint_col", "BIGINT"},
{"testSchema", "testTable1", "float4_col", "FLOAT4"},
{"testSchema", "testTable1", "float8_col", "FLOAT8"},
{"testSchema", "testTable1", "numeric_col", "NUMERIC"},
{"testSchema", "testTable1", "boolean_col", "BOOLEAN"},
{"testSchema", "testTable1", "varchar_col", "VARCHAR"},
{"testSchema", "testTable1", "timestamp_col", "TIMESTAMP"},
{"testSchema", "testTable1", "timestamptz_col", "TIMESTAMPTZ"},
{"testSchema", "testTable1", "datetime_col", "DATETIME"},
{"testSchema", "testTable1", "unknown_col", "UNKNOWN"}
};
int[] types = {
Types.BIT, Types.TINYINT, Types.SMALLINT, Types.INTEGER, Types.BIGINT,
Types.FLOAT, Types.DOUBLE, Types.NUMERIC, Types.BOOLEAN, Types.VARCHAR,
Types.TIMESTAMP, Types.TIMESTAMP, Types.DATE, Types.OTHER
};

AtomicInteger rowNumber = new AtomicInteger(-1);
ResultSet resultSet = mockResultSet(schema, types, values, rowNumber);
Expand All @@ -75,12 +103,52 @@ public void buildTableSchema() throws SQLException
VerticaSchemaUtils verticaSchemaUtils = new VerticaSchemaUtils();
Schema mockSchema = verticaSchemaUtils.buildTableSchema(this.connection, tableName);

Field testDateField = mockSchema.findField("date");
Assert.assertEquals("Utf8", testDateField.getType().toString());
assertEquals("Bool", mockSchema.findField("bit_col").getType().toString());
assertEquals("Int(8, true)", mockSchema.findField("tinyint_col").getType().toString());
assertEquals("Int(16, true)", mockSchema.findField("smallint_col").getType().toString());
assertEquals("Int(64, true)", mockSchema.findField("integer_col").getType().toString());
assertEquals("Int(64, true)", mockSchema.findField("bigint_col").getType().toString());
assertEquals("FloatingPoint(SINGLE)", mockSchema.findField("float4_col").getType().toString());
assertEquals("FloatingPoint(DOUBLE)", mockSchema.findField("float8_col").getType().toString());
assertEquals("Decimal(10, 2, 128)", mockSchema.findField("numeric_col").getType().toString());
assertEquals("Utf8", mockSchema.findField("boolean_col").getType().toString());
assertEquals("Utf8", mockSchema.findField("varchar_col").getType().toString());
assertEquals("Utf8", mockSchema.findField("timestamp_col").getType().toString());
assertEquals("Utf8", mockSchema.findField("timestamptz_col").getType().toString());
assertEquals("Date(DAY)", mockSchema.findField("datetime_col").getType().toString());
assertEquals("Utf8", mockSchema.findField("unknown_col").getType().toString());
}

Field testPriceField = mockSchema.findField("price");
Assert.assertEquals("FloatingPoint(SINGLE)", testPriceField.getType().toString());
@Test
public void buildTableSchemaSQLException() throws SQLException
{
when(databaseMetaData.getColumns(null, tableName.getSchemaName(), tableName.getTableName(), null))
.thenThrow(new SQLException("Database error"));

VerticaSchemaUtils verticaSchemaUtils = new VerticaSchemaUtils();
try {
verticaSchemaUtils.buildTableSchema(connection, tableName);
fail("Expected RuntimeException");
} catch (RuntimeException e) {
assertTrue(e.getMessage().contains("Error in building the table schema"));
assertTrue(e.getCause() instanceof SQLException);
}
}

@Test
public void formatCrtFileContents() throws Exception
{
when(mockCertificate.getEncoded()).thenReturn("test-cert".getBytes());

Method formatMethod = VerticaSchemaUtils.class.getDeclaredMethod("formatCrtFileContents", Certificate.class);
formatMethod.setAccessible(true);
String result = (String) formatMethod.invoke(null, mockCertificate);

String expectedEncoded = Base64.getMimeEncoder(64, System.lineSeparator().getBytes())
.encodeToString("test-cert".getBytes());
String expected = "-----BEGIN CERTIFICATE-----" + System.lineSeparator() +
expectedEncoded + System.lineSeparator() + "-----END CERTIFICATE-----";
assertEquals(expected, result);
}
}

}
Loading