@@ -27,6 +27,7 @@ def __init__(
27
27
description : Optional [str ] = "" ,
28
28
tags : Optional [Dict [str , str ]] = None ,
29
29
owner : Optional [str ] = "" ,
30
+ database : Optional [str ] = "" ,
30
31
):
31
32
"""
32
33
Creates a RedshiftSource object.
@@ -47,11 +48,12 @@ def __init__(
47
48
tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
48
49
owner (optional): The owner of the redshift source, typically the email of the primary
49
50
maintainer.
51
+ database (optional): The Redshift database name.
50
52
"""
51
53
# The default Redshift schema is named "public".
52
54
_schema = "public" if table and not schema else schema
53
55
self .redshift_options = RedshiftOptions (
54
- table = table , schema = _schema , query = query
56
+ table = table , schema = _schema , query = query , database = database
55
57
)
56
58
57
59
if table is None and query is None :
@@ -102,6 +104,7 @@ def from_proto(data_source: DataSourceProto):
102
104
description = data_source .description ,
103
105
tags = dict (data_source .tags ),
104
106
owner = data_source .owner ,
107
+ database = data_source .redshift_options .database ,
105
108
)
106
109
107
110
# Note: Python requires redefining hash in child classes that override __eq__
@@ -119,6 +122,7 @@ def __eq__(self, other):
119
122
and self .redshift_options .table == other .redshift_options .table
120
123
and self .redshift_options .schema == other .redshift_options .schema
121
124
and self .redshift_options .query == other .redshift_options .query
125
+ and self .redshift_options .database == other .redshift_options .database
122
126
and self .event_timestamp_column == other .event_timestamp_column
123
127
and self .created_timestamp_column == other .created_timestamp_column
124
128
and self .field_mapping == other .field_mapping
@@ -139,9 +143,14 @@ def schema(self):
139
143
140
144
@property
141
145
def query (self ):
142
- """Returns the Redshift options of this Redshift source."""
146
+ """Returns the Redshift query of this Redshift source."""
143
147
return self .redshift_options .query
144
148
149
+ @property
150
+ def database (self ):
151
+ """Returns the Redshift database of this Redshift source."""
152
+ return self .redshift_options .database
153
+
145
154
def to_proto (self ) -> DataSourceProto :
146
155
"""
147
156
Converts a RedshiftSource object to its protobuf representation.
@@ -197,12 +206,15 @@ def get_table_column_names_and_types(
197
206
assert isinstance (config .offline_store , RedshiftOfflineStoreConfig )
198
207
199
208
client = aws_utils .get_redshift_data_client (config .offline_store .region )
200
-
201
209
if self .table is not None :
202
210
try :
203
211
table = client .describe_table (
204
212
ClusterIdentifier = config .offline_store .cluster_id ,
205
- Database = config .offline_store .database ,
213
+ Database = (
214
+ self .database
215
+ if self .database
216
+ else config .offline_store .database
217
+ ),
206
218
DbUser = config .offline_store .user ,
207
219
Table = self .table ,
208
220
Schema = self .schema ,
@@ -221,7 +233,7 @@ def get_table_column_names_and_types(
221
233
statement_id = aws_utils .execute_redshift_statement (
222
234
client ,
223
235
config .offline_store .cluster_id ,
224
- config .offline_store .database ,
236
+ self . database if self . database else config .offline_store .database ,
225
237
config .offline_store .user ,
226
238
f"SELECT * FROM ({ self .query } ) LIMIT 1" ,
227
239
)
@@ -238,11 +250,16 @@ class RedshiftOptions:
238
250
"""
239
251
240
252
def __init__ (
241
- self , table : Optional [str ], schema : Optional [str ], query : Optional [str ]
253
+ self ,
254
+ table : Optional [str ],
255
+ schema : Optional [str ],
256
+ query : Optional [str ],
257
+ database : Optional [str ],
242
258
):
243
259
self ._table = table
244
260
self ._schema = schema
245
261
self ._query = query
262
+ self ._database = database
246
263
247
264
@property
248
265
def query (self ):
@@ -274,6 +291,16 @@ def schema(self, schema):
274
291
"""Sets the schema of this Redshift table."""
275
292
self ._schema = schema
276
293
294
+ @property
295
+ def database (self ):
296
+ """Returns the schema name of this Redshift table."""
297
+ return self ._database
298
+
299
+ @database .setter
300
+ def database (self , database ):
301
+ """Sets the database name of this Redshift table."""
302
+ self ._database = database
303
+
277
304
@classmethod
278
305
def from_proto (cls , redshift_options_proto : DataSourceProto .RedshiftOptions ):
279
306
"""
@@ -289,6 +316,7 @@ def from_proto(cls, redshift_options_proto: DataSourceProto.RedshiftOptions):
289
316
table = redshift_options_proto .table ,
290
317
schema = redshift_options_proto .schema ,
291
318
query = redshift_options_proto .query ,
319
+ database = redshift_options_proto .database ,
292
320
)
293
321
294
322
return redshift_options
@@ -301,7 +329,10 @@ def to_proto(self) -> DataSourceProto.RedshiftOptions:
301
329
A RedshiftOptionsProto protobuf.
302
330
"""
303
331
redshift_options_proto = DataSourceProto .RedshiftOptions (
304
- table = self .table , schema = self .schema , query = self .query ,
332
+ table = self .table ,
333
+ schema = self .schema ,
334
+ query = self .query ,
335
+ database = self .database ,
305
336
)
306
337
307
338
return redshift_options_proto
@@ -314,7 +345,7 @@ class SavedDatasetRedshiftStorage(SavedDatasetStorage):
314
345
315
346
def __init__ (self , table_ref : str ):
316
347
self .redshift_options = RedshiftOptions (
317
- table = table_ref , schema = None , query = None
348
+ table = table_ref , schema = None , query = None , database = None
318
349
)
319
350
320
351
@staticmethod
0 commit comments