3737import org .slf4j .Logger ;
3838import org .slf4j .LoggerFactory ;
3939
40+ import java .math .BigDecimal ;
4041import java .sql .Connection ;
4142import java .sql .SQLException ;
43+ import java .sql .Timestamp ;
44+ import java .text .SimpleDateFormat ;
45+ import java .time .LocalDateTime ;
46+ import java .time .format .DateTimeFormatter ;
4247import java .util .ArrayList ;
4348import java .util .Collections ;
49+ import java .util .Date ;
4450import java .util .List ;
51+ import java .util .concurrent .TimeUnit ;
4552import java .util .stream .Collectors ;
4653
54+ import static org .apache .arrow .vector .types .pojo .ArrowType .ArrowTypeID .Utf8 ;
55+
4756/**
4857 * Extends {@link JdbcSplitQueryBuilder} and implements MySql specific SQL clauses for split.
4958 *
5261public class SnowflakeQueryStringBuilder
5362 extends JdbcSplitQueryBuilder
5463{
55- private static final String EMPTY_STRING = "" ;
5664 private static final Logger LOGGER = LoggerFactory .getLogger (SnowflakeQueryStringBuilder .class );
57- private final String quoteCharacters = "\" " ;
65+ private static final String quoteCharacters = "\" " ;
66+ private static final String singleQuoteCharacters = "\' " ;
5867
5968 public SnowflakeQueryStringBuilder (final String quoteCharacters , final FederationExpressionParser federationExpressionParser )
6069 {
@@ -136,6 +145,12 @@ protected String quote(String name)
136145 return quoteCharacters + name + quoteCharacters ;
137146 }
138147
148+ protected String singleQuote (String name )
149+ {
150+ name = name .replace (singleQuoteCharacters , singleQuoteCharacters + singleQuoteCharacters );
151+ return singleQuoteCharacters + name + singleQuoteCharacters ;
152+ }
153+
139154 private List <String > toConjuncts (List <Field > columns , Constraints constraints , List <TypeAndValue > accumulator )
140155 {
141156 List <String > conjuncts = new ArrayList <>();
@@ -181,10 +196,10 @@ private String toPredicate(String columnName, ValueSet valueSet, ArrowType type,
181196 if (!range .getLow ().isLowerUnbounded ()) {
182197 switch (range .getLow ().getBound ()) {
183198 case ABOVE :
184- rangeConjuncts .add (toPredicate (columnName , ">" , range .getLow ().getValue (), type , accumulator ));
199+ rangeConjuncts .add (toPredicate (columnName , ">" , range .getLow ().getValue (), type ));
185200 break ;
186201 case EXACTLY :
187- rangeConjuncts .add (toPredicate (columnName , ">=" , range .getLow ().getValue (), type , accumulator ));
202+ rangeConjuncts .add (toPredicate (columnName , ">=" , range .getLow ().getValue (), type ));
188203 break ;
189204 case BELOW :
190205 throw new IllegalArgumentException ("Low marker should never use BELOW bound" );
@@ -197,10 +212,10 @@ private String toPredicate(String columnName, ValueSet valueSet, ArrowType type,
197212 case ABOVE :
198213 throw new IllegalArgumentException ("High marker should never use ABOVE bound" );
199214 case EXACTLY :
200- rangeConjuncts .add (toPredicate (columnName , "<=" , range .getHigh ().getValue (), type , accumulator ));
215+ rangeConjuncts .add (toPredicate (columnName , "<=" , range .getHigh ().getValue (), type ));
201216 break ;
202217 case BELOW :
203- rangeConjuncts .add (toPredicate (columnName , "<" , range .getHigh ().getValue (), type , accumulator ));
218+ rangeConjuncts .add (toPredicate (columnName , "<" , range .getHigh ().getValue (), type ));
204219 break ;
205220 default :
206221 throw new AssertionError ("Unhandled bound: " + range .getHigh ().getBound ());
@@ -214,17 +229,86 @@ private String toPredicate(String columnName, ValueSet valueSet, ArrowType type,
214229
215230 // Add back all of the possible single values either as an equality or an IN predicate
216231 if (singleValues .size () == 1 ) {
217- disjuncts .add (toPredicate (columnName , "=" , Iterables .getOnlyElement (singleValues ), type , accumulator ));
232+ disjuncts .add (toPredicate (columnName , "=" , Iterables .getOnlyElement (singleValues ), type ));
218233 }
219234 else if (singleValues .size () > 1 ) {
235+ List <Object > val = new ArrayList <>();
220236 for (Object value : singleValues ) {
221- accumulator .add (new TypeAndValue ( type , value ));
237+ val .add ((( type . getTypeID (). equals ( Utf8 ) || type . getTypeID (). equals ( ArrowType . ArrowTypeID . Date )) ? singleQuote ( getObjectForWhereClause ( columnName , value , type ). toString ()) : getObjectForWhereClause ( columnName , value , type ) ));
222238 }
223- String values = Joiner .on ("," ).join (Collections . nCopies ( singleValues . size (), "?" ) );
224- disjuncts .add (quote ( columnName ) + " IN (" + values + ")" );
239+ String values = Joiner .on ("," ).join (val );
240+ disjuncts .add (columnName + " IN (" + values + ")" );
225241 }
226242 }
227-
228243 return "(" + Joiner .on (" OR " ).join (disjuncts ) + ")" ;
229244 }
245+
246+ protected String toPredicate (String columnName , String operator , Object value , ArrowType type )
247+ {
248+ return columnName + " " + operator + " " + ((type .getTypeID ().equals (Utf8 ) || type .getTypeID ().equals (ArrowType .ArrowTypeID .Date )) ? singleQuote (getObjectForWhereClause (columnName , value , type ).toString ()) : getObjectForWhereClause (columnName , value , type ));
249+ }
250+
251+ private static Object getObjectForWhereClause (String columnName , Object value , ArrowType arrowType )
252+ {
253+ String val ;
254+ StringBuilder tempVal ;
255+
256+ switch (arrowType .getTypeID ()) {
257+ case Int :
258+ return ((Number ) value ).longValue ();
259+ case Decimal :
260+ if (value instanceof BigDecimal ) {
261+ return (BigDecimal ) value ;
262+ }
263+ else if (value instanceof Number ) {
264+ return BigDecimal .valueOf (((Number ) value ).doubleValue ());
265+ }
266+ else {
267+ throw new IllegalArgumentException ("Unexpected type for decimal conversion: " + value .getClass ().getName ());
268+ }
269+ case FloatingPoint :
270+ return (double ) value ;
271+ case Bool :
272+ return (Boolean ) value ;
273+ case Utf8 :
274+ return value .toString ();
275+ case Date :
276+ val = value .toString ();
277+ if (val .contains ("-" ) && val .length () == 16 ) {
278+ LocalDateTime dateTime = LocalDateTime .parse (val );
279+ DateTimeFormatter formatter = DateTimeFormatter .ofPattern ("yyyy-MM-dd HH:mm:ss" );
280+ return dateTime .format (formatter );
281+ }
282+ else if (val .contains ("-" )) {
283+ tempVal = new StringBuilder (val );
284+ tempVal = tempVal .length () == 19 ? tempVal .append (".0" ) : tempVal ;
285+ val = String .format ("%-26s" , tempVal ).replace (' ' , '0' ).replace ("T" , " " );
286+ return val ; // Returning as string formatted datetime
287+ }
288+ else {
289+ long days = Long .parseLong (val );
290+ long milliseconds = TimeUnit .DAYS .toMillis (days );
291+ return new SimpleDateFormat ("yyyy-MM-dd" ).format (new Date (milliseconds ));
292+ }
293+ case Timestamp :
294+ long millis = ((Number ) value ).longValue ();
295+ Timestamp timestamp = new Timestamp (millis );
296+ SimpleDateFormat sdf = new SimpleDateFormat ("yyyy-MM-dd HH:mm:ss" );
297+ return sdf .format (timestamp );
298+ case Time :
299+ case Interval :
300+ case Binary :
301+ case FixedSizeBinary :
302+ case Null :
303+ case Struct :
304+ case List :
305+ case FixedSizeList :
306+ case Union :
307+ case NONE :
308+ throw new UnsupportedOperationException ("The Arrow type: " + arrowType .getTypeID ().name () + " is currently not supported" );
309+ default :
310+ throw new IllegalArgumentException ("Unknown type encountered during processing: " + columnName +
311+ " Field Type: " + arrowType .getTypeID ().name ());
312+ }
313+ }
230314}
0 commit comments