Skip to content

Commit

Permalink
Merge pull request #13 from stanford-oval/litellm_migration
Browse files Browse the repository at this point in the history
added logging
  • Loading branch information
liamjxu authored Apr 11, 2024
2 parents 7780e34 + 062fa1c commit 0ec931f
Showing 1 changed file with 23 additions and 7 deletions.
30 changes: 23 additions & 7 deletions src/suql/sql_free_text_support/execute_free_text_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import string
import time
import traceback
import logging
from collections import defaultdict
from copy import deepcopy
from typing import List, Union
Expand Down Expand Up @@ -229,7 +230,7 @@ def visit_SelectStmt(self, ancestors, node: SelectStmt):
list(map(lambda x: f'"{x[0]}" {x[1]}', column_info))
)
create_stmt = f"CREATE TABLE {tmp_table_name} (\n{column_create_stmt}\n); GRANT SELECT ON {tmp_table_name} TO {self.select_username};"
print("created table {}".format(tmp_table_name))
logging.info("created table {}".format(tmp_table_name))
execute_sql(
create_stmt,
user=self.create_username,
Expand Down Expand Up @@ -484,9 +485,9 @@ def verify_single_value(single_value, single_column_name):
)
)
if all_found:
print("\n".join(found_stmt))
logging.info("\n".join(found_stmt))
elif found_stmt:
print("partially verified: " + "\n".join(found_stmt))
logging.info("partially verified: " + "\n".join(found_stmt))

return all_found

Expand Down Expand Up @@ -677,7 +678,7 @@ def _retrieve_and_verify(
id_res.append(each_res[0])

end_time = time.time()
print("retrieve + verification time {}s".format(end_time - start_time))
logging.info("retrieve + verification time {}s".format(end_time - start_time))

if single_table:
res = list(filter(lambda x: x[id_index] in id_res, existing_results))
Expand Down Expand Up @@ -906,13 +907,13 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
password=self.select_userpswd,
)
except psyconpg2Error:
print(
logging.info(
"above error happens during ENUM classification attempts. Marking this predicate as returning answer."
)
res = True

if not res:
print("determined the above predicate returns no result")
logging.info("determined the above predicate returns no result")
# try to classify into one of the known values
# first, we need to find out what is the value here - some heuristics here to find out
column_name, value_res = _get_a_expr_field_value(node)
Expand All @@ -926,7 +927,7 @@ def visit_A_Expr(self, ancestors: Ancestor, node: A_Expr):
else:
raise ValueError()

print(
logging.info(
"determined column name: {}; value: {}".format(
column_name, value_res_clear
)
Expand Down Expand Up @@ -1524,6 +1525,7 @@ def suql_execute(
llm_model_name="gpt-3.5-turbo-0125",
max_verify=20,
loggings="",
log_filename=None,
disable_try_catch=False,
embedding_server_address="http://127.0.0.1:8501",
select_username="select_user",
Expand Down Expand Up @@ -1553,6 +1555,8 @@ def suql_execute(
`loggings` (str, optional): Prefix for error case loggings. Errors are written to a "_suql_error_log.txt"
file by default.
`log_filename` (str, optional): Logging file name for the SUQL compiler. If not provided, logging is disabled.
`disable_try_catch` (bool, optional): whether to disable try-catch (errors would directly propagate to caller).
Expand Down Expand Up @@ -1589,6 +1593,18 @@ def suql_execute(
Ideally, this query should match against all `Mcdonald's`, as opposed to just 'mcdonalds'.
FTS helps with such cases.
"""
if log_filename:
logging.basicConfig(level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
handlers=[
logging.FileHandler(log_filename),
logging.StreamHandler()
])

else:
logging.basicConfig(level=logging.CRITICAL + 1)

results, column_names, cache = _suql_execute_single(
suql,
table_w_ids,
Expand Down

0 comments on commit 0ec931f

Please sign in to comment.