1919from google .cloud .environment_vars import BIGTABLE_EMULATOR
2020from google .cloud .bigtable .data import BigtableDataClientAsync
2121from google .cloud .bigtable .data ._cross_sync import CrossSync
22+ from helpers import sql_encoding_helpers
2223
2324if not CrossSync .is_async :
2425 from client_handler_data_async import error_safe
@@ -32,6 +33,7 @@ def error_safe(func):
3233 Catch and pass errors back to the grpc_server_process
3334 Also check if client is closed before processing requests
3435 """
36+
3537 async def wrapper (self , * args , ** kwargs ):
3638 try :
3739 if self .closed :
@@ -50,6 +52,7 @@ def encode_exception(exc):
5052 Encode an exception or chain of exceptions to pass back to grpc_handler
5153 """
5254 from google .api_core .exceptions import GoogleAPICallError
55+
5356 error_msg = f"{ type (exc ).__name__ } : { exc } "
5457 result = {"error" : error_msg }
5558 if exc .__cause__ :
@@ -113,7 +116,9 @@ async def ReadRows(self, request, **kwargs):
113116 table_id = request .pop ("table_name" ).split ("/" )[- 1 ]
114117 app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
115118 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
116- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
119+ kwargs ["operation_timeout" ] = (
120+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
121+ )
117122 result_list = CrossSync .rm_aio (await table .read_rows (request , ** kwargs ))
118123 # pack results back into protobuf-parsable format
119124 serialized_response = [row ._to_dict () for row in result_list ]
@@ -124,7 +129,9 @@ async def ReadRow(self, row_key, **kwargs):
124129 table_id = kwargs .pop ("table_name" ).split ("/" )[- 1 ]
125130 app_profile_id = self .app_profile_id or kwargs .get ("app_profile_id" , None )
126131 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
127- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
132+ kwargs ["operation_timeout" ] = (
133+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
134+ )
128135 result_row = CrossSync .rm_aio (await table .read_row (row_key , ** kwargs ))
129136 # pack results back into protobuf-parsable format
130137 if result_row :
@@ -135,10 +142,13 @@ async def ReadRow(self, row_key, **kwargs):
135142 @error_safe
136143 async def MutateRow (self , request , ** kwargs ):
137144 from google .cloud .bigtable .data .mutations import Mutation
145+
138146 table_id = request ["table_name" ].split ("/" )[- 1 ]
139147 app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
140148 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
141- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
149+ kwargs ["operation_timeout" ] = (
150+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
151+ )
142152 row_key = request ["row_key" ]
143153 mutations = [Mutation ._from_dict (d ) for d in request ["mutations" ]]
144154 CrossSync .rm_aio (await table .mutate_row (row_key , mutations , ** kwargs ))
@@ -147,21 +157,29 @@ async def MutateRow(self, request, **kwargs):
147157 @error_safe
148158 async def BulkMutateRows (self , request , ** kwargs ):
149159 from google .cloud .bigtable .data .mutations import RowMutationEntry
160+
150161 table_id = request ["table_name" ].split ("/" )[- 1 ]
151162 app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
152163 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
153- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
154- entry_list = [RowMutationEntry ._from_dict (entry ) for entry in request ["entries" ]]
164+ kwargs ["operation_timeout" ] = (
165+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
166+ )
167+ entry_list = [
168+ RowMutationEntry ._from_dict (entry ) for entry in request ["entries" ]
169+ ]
155170 CrossSync .rm_aio (await table .bulk_mutate_rows (entry_list , ** kwargs ))
156171 return "OK"
157172
158173 @error_safe
159174 async def CheckAndMutateRow (self , request , ** kwargs ):
160175 from google .cloud .bigtable .data .mutations import Mutation , SetCell
176+
161177 table_id = request ["table_name" ].split ("/" )[- 1 ]
162178 app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
163179 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
164- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
180+ kwargs ["operation_timeout" ] = (
181+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
182+ )
165183 row_key = request ["row_key" ]
166184 # add default values for incomplete dicts, so they can still be parsed to objects
167185 true_mutations = []
@@ -180,33 +198,44 @@ async def CheckAndMutateRow(self, request, **kwargs):
180198 # invalid mutation type. Conformance test may be sending generic empty request
181199 false_mutations .append (SetCell ("" , "" , "" , 0 ))
182200 predicate_filter = request .get ("predicate_filter" , None )
183- result = CrossSync .rm_aio (await table .check_and_mutate_row (
184- row_key ,
185- predicate_filter ,
186- true_case_mutations = true_mutations ,
187- false_case_mutations = false_mutations ,
188- ** kwargs ,
189- ))
201+ result = CrossSync .rm_aio (
202+ await table .check_and_mutate_row (
203+ row_key ,
204+ predicate_filter ,
205+ true_case_mutations = true_mutations ,
206+ false_case_mutations = false_mutations ,
207+ ** kwargs ,
208+ )
209+ )
190210 return result
191211
192212 @error_safe
193213 async def ReadModifyWriteRow (self , request , ** kwargs ):
194214 from google .cloud .bigtable .data .read_modify_write_rules import IncrementRule
195215 from google .cloud .bigtable .data .read_modify_write_rules import AppendValueRule
216+
196217 table_id = request ["table_name" ].split ("/" )[- 1 ]
197218 app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
198219 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
199- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
220+ kwargs ["operation_timeout" ] = (
221+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
222+ )
200223 row_key = request ["row_key" ]
201224 rules = []
202225 for rule_dict in request .get ("rules" , []):
203226 qualifier = rule_dict ["column_qualifier" ]
204227 if "append_value" in rule_dict :
205- new_rule = AppendValueRule (rule_dict ["family_name" ], qualifier , rule_dict ["append_value" ])
228+ new_rule = AppendValueRule (
229+ rule_dict ["family_name" ], qualifier , rule_dict ["append_value" ]
230+ )
206231 else :
207- new_rule = IncrementRule (rule_dict ["family_name" ], qualifier , rule_dict ["increment_amount" ])
232+ new_rule = IncrementRule (
233+ rule_dict ["family_name" ], qualifier , rule_dict ["increment_amount" ]
234+ )
208235 rules .append (new_rule )
209- result = CrossSync .rm_aio (await table .read_modify_write_row (row_key , rules , ** kwargs ))
236+ result = CrossSync .rm_aio (
237+ await table .read_modify_write_row (row_key , rules , ** kwargs )
238+ )
210239 # pack results back into protobuf-parsable format
211240 if result :
212241 return result ._to_dict ()
@@ -218,6 +247,55 @@ async def SampleRowKeys(self, request, **kwargs):
218247 table_id = request ["table_name" ].split ("/" )[- 1 ]
219248 app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
220249 table = self .client .get_table (self .instance_id , table_id , app_profile_id )
221- kwargs ["operation_timeout" ] = kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
250+ kwargs ["operation_timeout" ] = (
251+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
252+ )
222253 result = CrossSync .rm_aio (await table .sample_row_keys (** kwargs ))
223254 return result
255+
256+ @error_safe
257+ async def ExecuteQuery (self , request , ** kwargs ):
258+ app_profile_id = self .app_profile_id or request .get ("app_profile_id" , None )
259+ query = request .get ("query" )
260+ params = request .get ("params" ) or {}
261+ # Note that the request has been coverted to json, and the code for this converts
262+ # query param names to snake case. convert_params reverses this conversion. For this
263+ # reason, snake case params will have issues if they're used in the conformance tests.
264+ formatted_params , parameter_types = sql_encoding_helpers .convert_params (params )
265+ operation_timeout = (
266+ kwargs .get ("operation_timeout" , self .per_operation_timeout ) or 20
267+ )
268+ result = CrossSync .rm_aio (
269+ await self .client .execute_query (
270+ query ,
271+ self .instance_id ,
272+ parameters = formatted_params ,
273+ parameter_types = parameter_types ,
274+ app_profile_id = app_profile_id ,
275+ operation_timeout = operation_timeout ,
276+ prepare_operation_timeout = operation_timeout ,
277+ )
278+ )
279+ rows = [r async for r in result ]
280+ md = result .metadata
281+ proto_rows = []
282+ for r in rows :
283+ vals = []
284+ for c in md .columns :
285+ vals .append (sql_encoding_helpers .convert_value (c .column_type , r [c .column_name ]))
286+
287+ proto_rows .append ({"values" : vals })
288+
289+ proto_columns = []
290+ for c in md .columns :
291+ proto_columns .append (
292+ {
293+ "name" : c .column_name ,
294+ "type" : sql_encoding_helpers .convert_type (c .column_type ),
295+ }
296+ )
297+
298+ return {
299+ "metadata" : {"columns" : proto_columns },
300+ "rows" : proto_rows ,
301+ }
0 commit comments