Skip to content

Commit 4fa73a9

Browse files
authored
fix: Update redshift api (#2479)
* Fix Signed-off-by: Kevin Zhang <[email protected]> * Fix Signed-off-by: Kevin Zhang <[email protected]> * Remove warning Signed-off-by: Kevin Zhang <[email protected]>
1 parent 4864252 commit 4fa73a9

File tree

5 files changed

+49
-9
lines changed

5 files changed

+49
-9
lines changed

protos/feast/core/DataSource.proto

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ message DataSource {
145145

146146
// Redshift schema name
147147
string schema = 3;
148+
149+
// Redshift database name
150+
string database = 4;
148151
}
149152

150153
// Defines options for DataSource that sources features from a Snowflake Query

sdk/python/feast/infra/offline_stores/redshift_source.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
description: Optional[str] = "",
2828
tags: Optional[Dict[str, str]] = None,
2929
owner: Optional[str] = "",
30+
database: Optional[str] = "",
3031
):
3132
"""
3233
Creates a RedshiftSource object.
@@ -47,11 +48,12 @@ def __init__(
4748
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
4849
owner (optional): The owner of the redshift source, typically the email of the primary
4950
maintainer.
51+
database (optional): The Redshift database name.
5052
"""
5153
# The default Redshift schema is named "public".
5254
_schema = "public" if table and not schema else schema
5355
self.redshift_options = RedshiftOptions(
54-
table=table, schema=_schema, query=query
56+
table=table, schema=_schema, query=query, database=database
5557
)
5658

5759
if table is None and query is None:
@@ -102,6 +104,7 @@ def from_proto(data_source: DataSourceProto):
102104
description=data_source.description,
103105
tags=dict(data_source.tags),
104106
owner=data_source.owner,
107+
database=data_source.redshift_options.database,
105108
)
106109

107110
# Note: Python requires redefining hash in child classes that override __eq__
@@ -119,6 +122,7 @@ def __eq__(self, other):
119122
and self.redshift_options.table == other.redshift_options.table
120123
and self.redshift_options.schema == other.redshift_options.schema
121124
and self.redshift_options.query == other.redshift_options.query
125+
and self.redshift_options.database == other.redshift_options.database
122126
and self.event_timestamp_column == other.event_timestamp_column
123127
and self.created_timestamp_column == other.created_timestamp_column
124128
and self.field_mapping == other.field_mapping
@@ -139,9 +143,14 @@ def schema(self):
139143

140144
@property
141145
def query(self):
142-
"""Returns the Redshift options of this Redshift source."""
146+
"""Returns the Redshift query of this Redshift source."""
143147
return self.redshift_options.query
144148

149+
@property
150+
def database(self):
151+
"""Returns the Redshift database of this Redshift source."""
152+
return self.redshift_options.database
153+
145154
def to_proto(self) -> DataSourceProto:
146155
"""
147156
Converts a RedshiftSource object to its protobuf representation.
@@ -197,12 +206,15 @@ def get_table_column_names_and_types(
197206
assert isinstance(config.offline_store, RedshiftOfflineStoreConfig)
198207

199208
client = aws_utils.get_redshift_data_client(config.offline_store.region)
200-
201209
if self.table is not None:
202210
try:
203211
table = client.describe_table(
204212
ClusterIdentifier=config.offline_store.cluster_id,
205-
Database=config.offline_store.database,
213+
Database=(
214+
self.database
215+
if self.database
216+
else config.offline_store.database
217+
),
206218
DbUser=config.offline_store.user,
207219
Table=self.table,
208220
Schema=self.schema,
@@ -221,7 +233,7 @@ def get_table_column_names_and_types(
221233
statement_id = aws_utils.execute_redshift_statement(
222234
client,
223235
config.offline_store.cluster_id,
224-
config.offline_store.database,
236+
self.database if self.database else config.offline_store.database,
225237
config.offline_store.user,
226238
f"SELECT * FROM ({self.query}) LIMIT 1",
227239
)
@@ -238,11 +250,16 @@ class RedshiftOptions:
238250
"""
239251

240252
def __init__(
241-
self, table: Optional[str], schema: Optional[str], query: Optional[str]
253+
self,
254+
table: Optional[str],
255+
schema: Optional[str],
256+
query: Optional[str],
257+
database: Optional[str],
242258
):
243259
self._table = table
244260
self._schema = schema
245261
self._query = query
262+
self._database = database
246263

247264
@property
248265
def query(self):
@@ -274,6 +291,16 @@ def schema(self, schema):
274291
"""Sets the schema of this Redshift table."""
275292
self._schema = schema
276293

294+
@property
295+
def database(self):
296+
"""Returns the database name of this Redshift table."""
297+
return self._database
298+
299+
@database.setter
300+
def database(self, database):
301+
"""Sets the database name of this Redshift table."""
302+
self._database = database
303+
277304
@classmethod
278305
def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
279306
"""
@@ -289,6 +316,7 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
289316
table=redshift_options_proto.table,
290317
schema=redshift_options_proto.schema,
291318
query=redshift_options_proto.query,
319+
database=redshift_options_proto.database,
292320
)
293321

294322
return redshift_options
@@ -301,7 +329,10 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions:
301329
A RedshiftOptionsProto protobuf.
302330
"""
303331
redshift_options_proto = DataSourceProto.RedshiftOptions(
304-
table=self.table, schema=self.schema, query=self.query,
332+
table=self.table,
333+
schema=self.schema,
334+
query=self.query,
335+
database=self.database,
305336
)
306337

307338
return redshift_options_proto
@@ -314,7 +345,7 @@ class SavedDatasetRedshiftStorage(SavedDatasetStorage):
314345

315346
def __init__(self, table_ref: str):
316347
self.redshift_options = RedshiftOptions(
317-
table=table_ref, schema=None, query=None
348+
table=table_ref, schema=None, query=None, database=None
318349
)
319350

320351
@staticmethod

sdk/python/feast/templates/aws/bootstrap.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,17 @@ def bootstrap():
5353

5454
repo_path = pathlib.Path(__file__).parent.absolute()
5555
config_file = repo_path / "feature_store.yaml"
56+
driver_file = repo_path / "driver_repo.py"
5657

5758
replace_str_in_file(config_file, "%AWS_REGION%", aws_region)
5859
replace_str_in_file(config_file, "%REDSHIFT_CLUSTER_ID%", cluster_id)
5960
replace_str_in_file(config_file, "%REDSHIFT_DATABASE%", database)
61+
replace_str_in_file(driver_file, "%REDSHIFT_DATABASE%", database)
6062
replace_str_in_file(config_file, "%REDSHIFT_USER%", user)
6163
replace_str_in_file(
62-
config_file, "%REDSHIFT_S3_STAGING_LOCATION%", s3_staging_location
64+
driver_file, config_file, "%REDSHIFT_S3_STAGING_LOCATION%", s3_staging_location
6365
)
66+
replace_str_in_file(config_file,)
6467
replace_str_in_file(config_file, "%REDSHIFT_IAM_ROLE%", iam_role)
6568

6669

sdk/python/feast/templates/aws/driver_repo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@
2727
# The (optional) created timestamp is used to ensure there are no duplicate
2828
# feature rows in the offline store or when building training datasets
2929
created_timestamp_column="created",
30+
# The Redshift database that this source's table or query is read from.
31+
database="%REDSHIFT_DATABASE%",
3032
)
3133

3234
# Feature views are a grouping based on how features are stored in either the

sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def create_data_source(
6565
created_timestamp_column=created_timestamp_column,
6666
date_partition_column="",
6767
field_mapping=field_mapping or {"ts_1": "ts"},
68+
database=self.offline_store_config.database,
6869
)
6970

7071
def create_saved_dataset_destination(self) -> SavedDatasetRedshiftStorage:

0 commit comments

Comments
 (0)