Skip to content

Utilities

check_temp_query(sql)

Checks whether a query to a temporary table has wrapped `__temp__` in quote marks.

Parameters:

Name Type Description Default
sql str

an SQL query

required
Source code in pydbtools/utils.py
def check_temp_query(sql: str):
    """
    Checks if a query to a temporary table
    has had __temp__ wrapped in quote marks.

    The check is case-insensitive: it looks for ``__temp__`` surrounded
    by single or double quotes and immediately followed by a ``.``
    (e.g. ``"__temp__".my_table``).

    Args:
        sql (str): an SQL query

    Raises:
        ValueError: if ``__temp__`` is wrapped in quotes before a ``.``.
    """
    # The character class contains only the two quote characters;
    # inside [...] a "|" would be a literal pipe, not an alternation.
    if re.search(r'["\']__temp__["\']\.', sql.lower()):
        raise ValueError(
            "When querying a temporary database, "
            "__temp__ should not be wrapped in quotes"
        )

clean_query(sql, fmt_opts=None)

Removes comments, newlines, surrounding whitespace and trailing semicolons from sql for use with the sqlparse package.

Args:

- sql (str): The raw SQL query
- fmt_opts (dict): Dictionary of params to pass to sqlparse.format. `strip_comments` is always set to True, and sqlparse.format is called even when this is None.

Returns:

- str: The cleaned SQL query

Source code in pydbtools/utils.py
def clean_query(sql: str, fmt_opts: Optional[dict] = None) -> str:
    """
    Removes comments, newlines, surrounding whitespace and trailing
    semicolons from sql for use with the sqlparse package.

    Args:
        sql (str): The raw SQL query
        fmt_opts (dict): Optional dictionary of keyword arguments to pass
            to sqlparse.format. ``strip_comments`` is always forced to
            True. sqlparse.format is called even when this is None.
            The caller's dictionary is not modified.

    Returns:
        str: The cleaned SQL query
    """
    # Build a fresh dict so the caller's fmt_opts is never mutated.
    opts = {**(fmt_opts or {}), "strip_comments": True}
    sql = sqlparse.format(sql, **opts)
    # Collapse the formatted query onto a single line, then drop
    # surrounding whitespace and any trailing ";" characters.
    return " ".join(sql.splitlines()).strip().rstrip(";")

replace_temp_database_name_reference(sql, database_name)

Replaces references to the user's temp database `__temp__` with the database_name string provided.

Parameters:

Name Type Description Default
sql str

The raw SQL query as a string

required
database_name str

The database name that replaces `__temp__`

required

Returns:

Name Type Description
str str

The new SQL query which is sent to Athena

Source code in pydbtools/utils.py
def replace_temp_database_name_reference(sql: str, database_name: str) -> str:
    """
    Replaces references to the user's temp database __temp__
    with the database_name string provided.

    Each parsed statement is first validated with check_temp_query.
    Then every leaf token whose text starts with ``__temp__``
    (case-insensitive) has that prefix replaced by database_name.

    Args:
        sql (str): The raw SQL query as a string
        database_name (str): The database name to replace __temp__

    Returns:
        str: The new SQL query which is sent to Athena

    Raises:
        ValueError: if a statement wraps __temp__ in quotes
            (raised by check_temp_query).
    """

    # Use a callable replacement so database_name is inserted literally;
    # a plain replacement string would have backslash escapes and group
    # references (e.g. "\1") interpreted by re.sub.
    def _replacement(_match):
        return database_name

    new_query = []
    for query in sqlparse.parse(sql):
        check_temp_query(str(query))
        # Walk all the separated leaf tokens from the parse subtrees and
        # join them back together, replacing __temp__ where necessary
        new_query.append(
            "".join(
                re.sub("^__temp__", _replacement, str(word), flags=re.IGNORECASE)
                for word in query.flatten()
            )
        )
    # Strip output for consistency, different versions of sqlparse
    # treat a trailing newline differently
    return "".join(new_query).strip()

get_database_name_from_sql(sql)

Obtains database name from SQL query for use by awswrangler.

Parameters:

Name Type Description Default
sql str

The raw SQL query as a string

required

Returns:

Name Type Description
str str

The database name taken from the first table written as "database.table", or None if no table is qualified with a database

Source code in pydbtools/utils.py
def get_database_name_from_sql(sql: str) -> Optional[str]:
    """
    Obtains database name from SQL query for use
    by awswrangler.

    Args:
        sql (str): The raw SQL query as a string

    Returns:
        Optional[str]: The database name of the first table (in the
        order reported by ``sql_metadata.Parser(sql).tables``) that is
        qualified as "database.table" or "catalog.database.table".
        None if no table is qualified with a database.
    """

    for table in sql_metadata.Parser(sql).tables:
        parts = table.split(".")
        if len(parts) > 1:
            # The database is the component just before the table name,
            # so a leading catalog (catalog.database.table) is skipped.
            return parts[-2]

    # No database-qualified table was found
    return None