Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions c/driver/postgresql/statement.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,16 @@ struct BindStream {
type_id = PostgresTypeId::kBytea;
param_lengths[i] = 0;
break;
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
if (strcmp("", bind_schema_fields[i].timezone)) {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was a bit surprised that nanoarrow assigns an empty string to the timezone member when constructing a schema with a timestamp; I was expecting it to be nullptr.

SetError(error, "[libpq] Field #%" PRIi64 "%s%s%s",
static_cast<int64_t>(i + 1), " (\"", bind_schema->children[i]->name,
"\") has unsupported type code timestamp with timezone");
return ADBC_STATUS_NOT_IMPLEMENTED;
}
type_id = PostgresTypeId::kTimestamp;
param_lengths[i] = 8;
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", bind_schema->children[i]->name,
Expand Down Expand Up @@ -337,6 +347,53 @@ struct BindStream {
param_values[col] = const_cast<char*>(view.data.as_char);
break;
}
case ArrowType::NANOARROW_TYPE_TIMESTAMP: {
int64_t val = array_view->children[col]->buffer_views[1].data.as_int64[row];
if (strcmp("", bind_schema_fields[col].timezone)) {
SetError(error, "[libpq] Column #%" PRIi64 "%s%s%s", col + 1, " (\"",
PQfname(result, col),
"\") has unsupported type code timestamp with timezone");
return ADBC_STATUS_NOT_IMPLEMENTED;
}

// 2000-01-01 00:00:00.000000 in microseconds
constexpr int64_t kPostgresTimestampEpoch = 946684800000000;
constexpr int64_t kSecOverflowLimit = 9223372036854;
constexpr int64_t kmSecOverflowLimit = 9223372036854775;

auto unit = bind_schema_fields[col].time_unit;
switch (unit) {
case NANOARROW_TIME_UNIT_SECOND:
if (abs(val) > kSecOverflowLimit) {
SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s",
col + 1, "('", bind_schema->children[col]->name, "') Row #",
row + 1,
"has value which exceeds postgres timestamp limits");
return ADBC_STATUS_INVALID_ARGUMENT;
}
val *= 1000000;
break;
case NANOARROW_TIME_UNIT_MILLI:
if (abs(val) > kmSecOverflowLimit) {
SetError(error, "[libpq] Field #%" PRId64 "%s%s%s%" PRId64 "%s",
col + 1, "('", bind_schema->children[col]->name, "') Row #",
row + 1,
"has value which exceeds postgres timestamp limits");
return ADBC_STATUS_INVALID_ARGUMENT;
}
val *= 1000;
break;
case NANOARROW_TIME_UNIT_MICRO:
break;
case NANOARROW_TIME_UNIT_NANO:
val /= 1000;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should handle truncation/overflow here

break;
}

const uint64_t value = ToNetworkInt64(val - kPostgresTimestampEpoch);
std::memcpy(param_values[col], &value, sizeof(int64_t));
break;
}
default:
SetError(error, "%s%" PRId64 "%s%s%s%s", "[libpq] Field #", col + 1, " ('",
bind_schema->children[col]->name,
Expand Down Expand Up @@ -605,6 +662,15 @@ AdbcStatusCode PostgresStatement::CreateBulkTable(
case ArrowType::NANOARROW_TYPE_BINARY:
create += " BYTEA";
break;
case ArrowType::NANOARROW_TYPE_TIMESTAMP:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably differentiate between WITH/WITHOUT TIMEZONE?

Note that for timestamps WITH TIMEZONE, Arrow always stores the underlying value in UTC, so there's no need for us to do any time zone math (thankfully!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I mistakenly assumed that WITH TIMEZONE was a different type. For this PR I'm planning to just raise an error if a timezone is detected, as I'm not yet sure how to transmit that information via the binary protocol.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we need to check for timezone here, too?

if (strcmp("", source_schema_fields[i].timezone)) {
SetError(error, "[libpq] Field #%" PRIi64 "%s%s%s", static_cast<int64_t>(i + 1),
" (\"", source_schema.children[i]->name,
"\") has unsupported type for ingestion timestamp with timezone");
return ADBC_STATUS_NOT_IMPLEMENTED;
}
create += " TIMESTAMP";
break;
default:
SetError(error, "%s%" PRIu64 "%s%s%s%s", "[libpq] Field #",
static_cast<uint64_t>(i + 1), " ('", source_schema.children[i]->name,
Expand Down
3 changes: 3 additions & 0 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ class SqliteStatementTest : public ::testing::Test,

void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; }
void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
// Overrides the shared timestamp ingestion test for this fixture: the test is
// skipped because, per the skip message, TIMESTAMP ingestion is not implemented.
void TestSqlIngestTimestamp() {
GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)";
}

protected:
SqliteQuirks quirks_;
Expand Down
3 changes: 3 additions & 0 deletions c/driver_manager/adbc_driver_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ class SqliteStatementTest : public ::testing::Test,

void TestSqlIngestUInt64() { GTEST_SKIP() << "Cannot ingest UINT64 (out of range)"; }
void TestSqlIngestBinary() { GTEST_SKIP() << "Cannot ingest BINARY (not implemented)"; }
// Overrides the shared timestamp ingestion test for this fixture: the test is
// skipped because, per the skip message, TIMESTAMP ingestion is not implemented.
void TestSqlIngestTimestamp() {
GTEST_SKIP() << "Cannot ingest TIMESTAMP (not implemented)";
}

protected:
SqliteQuirks quirks_;
Expand Down
77 changes: 77 additions & 0 deletions c/validation/adbc_validation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,83 @@ void StatementTest::TestSqlIngestBinary() {
NANOARROW_TYPE_BINARY, {std::nullopt, "", "\x00\x01\x02\x04", "\xFE\xFF"}));
}

template <enum ArrowTimeUnit TU>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The templating here might be over-engineering. I was thinking we could also change the signature to something like:

void StatementTest::TestSqlIngestTemporalType(
    std::vector<std::optional<int64_t>>& values, 
    enum ArrowTimeUnit unit, const char* timezone) {
...
}

While nothing else is using that pattern currently, it might be nice for gtest to do something like:

  ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType(
    {std::nullopt, 0, 42}, NANOARROW_TIME_UNIT_SECOND, nullptr
  ));
  EXPECT_FATAL_FAILURE(TestSqlIngestTemporalType(
    {std::nullopt, INT64_MIN, INT64_MAX}, NANOARROW_TIME_UNIT_SECOND, nullptr),
    "overflow"
  );
  EXPECT_FATAL_FAILURE(TestSqlIngestTemporalType(
    {std::nullopt, 0, 42}, NANOARROW_TIME_UNIT_SECOND, "America/Los_Angeles"),
    "not implemented"
  );

(N.B. I have no experience with EXPECT_FATAL_FAILURE)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be easiest to have the signature be AdbcStatusCode TestSqlIngestTemporalType(..., struct AdbcError* error) and then you can use the usual ASSERT_THAT(..., IsOkStatus(&error))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to fix things here, but it should be quite straightforward to also write a matcher like IsError(&error, ::testing::HasSubstr("..."))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems reasonable as is though

void StatementTest::TestSqlIngestTemporalType() {
  // Bulk-ingests a single nullable timestamp column named "col" (time unit TU,
  // no timezone) into the "bulk_ingest" table, then reads the table back and
  // checks the row count, the result schema and (for microsecond input only)
  // the values. Skipped when the driver does not support bulk ingestion.
  if (!quirks()->supports_bulk_ingest()) {
    GTEST_SKIP();
  }

  // Start from a clean slate so earlier tests cannot influence the result.
  ASSERT_THAT(quirks()->DropTable(&connection, "bulk_ingest", &error),
              IsOkStatus(&error));

  Handle<struct ArrowSchema> schema;
  Handle<struct ArrowArray> array;
  struct ArrowError na_error;
  const std::vector<std::optional<int64_t>> values = {std::nullopt, 0, 42};

  // Much of this code is shared with TestSqlIngestType, with minor changes to
  // allow various time units to be tested. Every nanoarrow call that reports
  // an error code is asserted so that a schema-construction failure is
  // reported here rather than as a confusing failure further down.
  ArrowSchemaInit(&schema.value);
  ASSERT_THAT(ArrowSchemaSetTypeStruct(&schema.value, 1), IsOkErrno());
  ASSERT_THAT(ArrowSchemaSetTypeDateTime(schema->children[0], NANOARROW_TYPE_TIMESTAMP,
                                         TU, /*timezone=*/nullptr),
              IsOkErrno());
  ASSERT_THAT(ArrowSchemaSetName(schema->children[0], "col"), IsOkErrno());
  ASSERT_THAT(MakeBatch<int64_t>(&schema.value, &array.value, &na_error, values),
              IsOkErrno());

  ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));
  ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
                                     "bulk_ingest", &error),
              IsOkStatus(&error));
  ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
              IsOkStatus(&error));

  // The assertions accept either the exact row count or -1.
  int64_t rows_affected = 0;
  ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
              IsOkStatus(&error));
  ASSERT_THAT(rows_affected,
              ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));

  ASSERT_THAT(AdbcStatementSetSqlQuery(
                  &statement,
                  "SELECT * FROM bulk_ingest ORDER BY \"col\" ASC NULLS FIRST", &error),
              IsOkStatus(&error));
  {
    StreamReader reader;
    ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
                                          &reader.rows_affected, &error),
                IsOkStatus(&error));
    ASSERT_THAT(reader.rows_affected,
                ::testing::AnyOf(::testing::Eq(values.size()), ::testing::Eq(-1)));

    ASSERT_NO_FATAL_FAILURE(reader.GetSchema());
    ASSERT_NO_FATAL_FAILURE(CompareSchema(&reader.schema.value,
                                          {{"col", NANOARROW_TYPE_TIMESTAMP, NULLABLE}}));

    ASSERT_NO_FATAL_FAILURE(reader.Next());
    ASSERT_NE(nullptr, reader.array->release);
    ASSERT_EQ(values.size(), reader.array->length);
    ASSERT_EQ(1, reader.array->n_children);

    // TU is a template parameter, so this branch is resolved at compile time.
    if constexpr (TU == NANOARROW_TIME_UNIT_MICRO) {
      // Similar to the TestSqlIngestType implementation, values are only
      // compared when the unit round-trips unchanged.
      ASSERT_NO_FATAL_FAILURE(
          CompareArray<int64_t>(reader.array_view->children[0], values));
    }

    // A second Next() must report end-of-stream (released array).
    ASSERT_NO_FATAL_FAILURE(reader.Next());
    ASSERT_EQ(nullptr, reader.array->release);
  }
}

// Runs the temporal ingestion test once for each Arrow time unit (second,
// microsecond, millisecond, nanosecond). Each call is wrapped in
// ASSERT_NO_FATAL_FAILURE so a fatal failure in one unit stops the test here
// instead of continuing on to the remaining units.
void StatementTest::TestSqlIngestTimestamp() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_SECOND>());
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MICRO>());
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_MILLI>());
ASSERT_NO_FATAL_FAILURE(TestSqlIngestTemporalType<NANOARROW_TIME_UNIT_NANO>());
}

void StatementTest::TestSqlIngestAppend() {
if (!quirks()->supports_bulk_ingest()) {
GTEST_SKIP();
Expand Down
7 changes: 7 additions & 0 deletions c/validation/adbc_validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ class StatementTest {
void TestSqlIngestString();
void TestSqlIngestBinary();

// Temporal
void TestSqlIngestTimestamp();

// ---- End Type-specific tests ----------------

void TestSqlIngestAppend();
Expand Down Expand Up @@ -269,6 +272,9 @@ class StatementTest {

template <typename CType>
void TestSqlIngestNumericType(ArrowType type);

template <enum ArrowTimeUnit TU>
void TestSqlIngestTemporalType();
};

#define ADBCV_TEST_STATEMENT(FIXTURE) \
Expand All @@ -288,6 +294,7 @@ class StatementTest {
TEST_F(FIXTURE, SqlIngestFloat64) { TestSqlIngestFloat64(); } \
TEST_F(FIXTURE, SqlIngestString) { TestSqlIngestString(); } \
TEST_F(FIXTURE, SqlIngestBinary) { TestSqlIngestBinary(); } \
TEST_F(FIXTURE, SqlIngestTimestamp) { TestSqlIngestTimestamp(); } \
TEST_F(FIXTURE, SqlIngestAppend) { TestSqlIngestAppend(); } \
TEST_F(FIXTURE, SqlIngestErrors) { TestSqlIngestErrors(); } \
TEST_F(FIXTURE, SqlIngestMultipleConnections) { TestSqlIngestMultipleConnections(); } \
Expand Down