"""Database helpers."""
import getpass
import inspect
from contextlib import contextmanager
from functools import wraps
from typing import Generator
import numpy as np
import psycopg
from psycopg.adapt import Dumper
from psycopg.rows import DictRow, dict_row
from psycopg.sql import SQL, Identifier
from sqlalchemy import TextClause, create_engine, text
from sqlalchemy.engine.base import Connection
# NB: Careful of cyclic imports here...
from pyiem.exceptions import NewDatabaseConnectionFailure
# Map system users back to something supported by akrherz/iem-database repo.
# Keys are operating-system account names, values are the database role that
# get_dbconnstr substitutes when the caller does not pass an explicit ``user``.
# Accounts that are not listed here are used unchanged.
USERNAME_MAPPER: dict[str, str] = {
    "apache": "nobody",
    "www-data": "nobody",
    "akrherz": "mesonet",
    "runner": "mesonet",  # github actions
    "meteor_ldm": "ldm",
}
class _FloatDumper(Dumper):
    """Numeric dumper that keeps NaN values out of the database."""

    @staticmethod
    def dump(obj):
        """Serialize ``obj`` for PostgreSQL, swapping NaN for ``None``.

        Args:
            obj: a numeric scalar accepted by ``numpy.isnan``.

        Returns:
            ``None`` when ``obj`` is NaN, otherwise ``str(obj)`` encoded as
            bytes.
        """
        return None if np.isnan(obj) else str(obj).encode()
# Adapters for Python to PostgreSQL: route each of these numeric types
# through _FloatDumper.
for _numeric_type in (float, np.float32, np.float64, np.int64):
    psycopg.adapters.register_dumper(_numeric_type, _FloatDumper)
[docs]
def get_dbconnstr(name, **kwargs) -> str:
    """Create a database connection string/URI.

    Any option passed explicitly as ``None`` is treated the same as an
    omitted option, so callers may forward their own ``None`` defaults.

    Args:
      name (str): the database name to connect to.
      **kwargs: connection options, those recognized here are:
        user (str): the database user to connect as.  When missing, the
          current system user is looked up and translated through
          ``USERNAME_MAPPER``.
        host (str): the database host to connect to, default
          ``iemdb-<name>.local``.
        port (int): the database port to connect to, default 5432.
        connect_timeout (int): Connection timeout in seconds, default 30.
        gssencmode (str): value for the ``gssencmode`` URI parameter,
          default ``disable``.
        rw (bool | None): Require that the connected cluster can accept
          write requests.  With the default of `None`, a read-write session
          is required unless the user resolves to `nobody`, in which case
          any session is accepted.  `False` always accepts any session.

    Returns:
      str: a ``postgresql://`` URI.
    """
    rw = kwargs.get("rw")

    user = kwargs.get("user")
    if user is None:
        # Resolve the OS account once, then map it to a database role
        sysuser = getpass.getuser()
        user = USERNAME_MAPPER.get(sysuser, sysuser)

    # RW1. Default to requiring a read-write connection
    # RW4. Covers the case of rw=True
    target_session_attrs = "read-write"
    # RW2. If we compute nobody and no rw specified, any will work
    # RW3. If rw is explicitly False, then any will work
    if rw is False or (rw is None and user == "nobody"):
        target_session_attrs = "any"

    host = kwargs.get("host")
    if host is None:
        host = f"iemdb-{name}.local"

    port = kwargs.get("port")
    if port is None:
        port = 5432

    # 15 seconds found to be a bit tight for local ISU congestion
    connect_timeout = kwargs.get("connect_timeout")
    if connect_timeout is None:
        connect_timeout = 30

    gssencmode = kwargs.get("gssencmode")
    if gssencmode is None:
        gssencmode = "disable"

    return (
        f"postgresql://{user}@{host}:{port}/{name}?"
        f"connect_timeout={connect_timeout}&"
        f"gssencmode={gssencmode}&"
        f"target_session_attrs={target_session_attrs}"
    )
[docs]
def get_dbconn(database="mesosite", user=None, host=None, port=5432, **kwargs):
    """Open a psycopg connection using the project's connection conventions.

    The DSN is built by ``get_dbconnstr`` and the connection is attempted up
    to three times back to back.  Depending on the session attributes that
    ``get_dbconnstr`` computes, the returned connection may be read-only.

    Args:
      database (str,optional): the database name to connect to.
        default: mesosite
      user (str,optional): hard coded user to connect as, default: derived
        from the current system user
      host (str,optional): hard coded hostname to connect to,
        default: ``iemdb-<database>.local``
      port (int,optional): the TCP port that PostgreSQL is listening
        defaults to 5432
      password (str,optional): accepted, but not used when building the DSN.
      rw (bool | None): Require that the connected cluster can accept
        write requests.  The default of `None` lets ``get_dbconnstr``
        decide; if the `nobody` user is computed, read-only is acceptable.

    Returns:
      psycopg database connection

    Raises:
      NewDatabaseConnectionFailure: when the final attempt also fails, with
        the underlying exception chained as the cause.
    """
    dsn = get_dbconnstr(database, user=user, host=host, port=port, **kwargs)
    max_attempts = 3
    for attempt in range(1, max_attempts + 1):
        try:
            # FIXME make this opinionated to return a default row_factory
            return psycopg.connect(dsn)
        except Exception as exp:
            # Earlier failures are swallowed so the next attempt can run;
            # only the last one is reported to the caller.
            if attempt == max_attempts:
                raise NewDatabaseConnectionFailure(str(exp)) from exp
    # Not reachable: the last attempt either returns or raises
    return None
[docs]
def get_dbconnc(
    database: str = "mesosite",
    user: str = None,
    host: str = None,
    cursor_name: str = None,
    **kwargs,
) -> tuple[psycopg.Connection[DictRow], psycopg.ServerCursor[DictRow]]:
    """Return a database connection together with a cursor created from it.

    The connection comes from ``get_dbconn`` and has its ``row_factory`` set
    to ``dict_row`` before the cursor is created.  As with ``get_dbconn``,
    the connection may be read-only depending on the computed session
    attributes.

    Args:
      database (str,optional): the database name to connect to.
        default: mesosite
      user (str,optional): hard coded user to connect as, default: derived
        from the current system user
      host (str,optional): hard coded hostname to connect to,
        default: ``iemdb-<database>.local``
      cursor_name (str,optional): name of the cursor to create
      port (int,optional): the TCP port that PostgreSQL is listening
        defaults to 5432
      password (str,optional): accepted, but not used when building the DSN.
      rw (bool | None): Require that the connected cluster can accept
        write requests.  The default of `None` lets ``get_dbconnstr``
        decide; if the `nobody` user is computed, read-only is acceptable.

    Returns:
      tuple of (connection, cursor)
    """
    pgconn = get_dbconn(database, user=user, host=host, **kwargs)
    # Set the factory first so that the cursor below is built with it
    pgconn.row_factory = dict_row
    cursor = pgconn.cursor(cursor_name)
    return pgconn, cursor
[docs]
@contextmanager
def get_sqlalchemy_conn(
    name: str, **kwargs
) -> Generator[Connection, None, None]:
    """An auto-disposing sqlalchemy context-manager helper.

    This is used for when we really do not want to manage having pools of
    database connections open.  A new engine is created on entry and
    disposed of on exit, so this isn't something that is fast!

    Args:
      name (str): the database to connect to, passed to get_dbconnstr
      rw (bool | None): Require that the connected cluster can accept
        write requests.  The default of `None` lets ``get_dbconnstr``
        decide; if the `nobody` user is computed, read-only is acceptable.
      **kwargs: any additional arguments to pass to get_dbconnstr

    Yields:
      sqlalchemy Connection, closed when the context exits.
    """
    # Le Sigh, sqlalchemy needs the driver spelled out in the URI scheme.
    # Limit the swap to the first occurrence (the scheme) so that a user,
    # host or database name containing "postgresql" is left untouched.
    connstr = get_dbconnstr(name, **kwargs).replace(
        "postgresql",
        "postgresql+psycopg",
        1,
    )
    engine = create_engine(connstr)
    try:
        # This seems to be a best practice as the finally will always clean
        # up the connection
        with engine.connect() as conn:
            yield conn
    finally:
        engine.dispose()
[docs]
def with_sqlalchemy_conn(name: str, **kwargs):
    """Decorator form of get_sqlalchemy_conn that injects ``conn=``.

    Each call of the decorated function opens a connection through
    ``get_sqlalchemy_conn`` and passes it in as the ``conn`` keyword
    argument.  Generator functions remain generators; the connection stays
    open while they are being iterated.

    Usage::

        @with_sqlalchemy_conn("dbname")
        def foo(args, conn=None, **kwargs):
            ...

    Note:
        Be sure to commit any transactions before returning from the
        decorated function.

    Args:
      name (str): the database to connect to, passed to get_dbconnstr
      rw (bool | None): Require that the connected cluster can accept
        write requests.  The default of `None` lets ``get_dbconnstr``
        decide; if the `nobody` user is computed, read-only is acceptable.
      **kwargs: any additional arguments to pass to get_dbconnstr
    """

    def decorator(func):
        if not inspect.isgeneratorfunction(func):
            # Plain callable: hold the connection for the duration of the call
            @wraps(func)
            def call_with_conn(*args, **kwds):
                with get_sqlalchemy_conn(name, **kwargs) as conn:
                    return func(*args, **kwds, conn=conn)

            return call_with_conn

        # Generator: delegate from inside the ``with`` so the connection
        # outlives every yielded item
        @wraps(func)
        def iterate_with_conn(*args, **kwds):
            with get_sqlalchemy_conn(name, **kwargs) as conn:
                yield from func(*args, **kwds, conn=conn)

        return iterate_with_conn

    return decorator
[docs]
def sql_helper(sql: str, **kwargs) -> TextClause:
    """Run string through psycopg.sql machinery destined for sqlalchemy.

    Allows for removal of boilerplate and appeases SQL injection detection.
    The ``table`` keyword is wrapped as a psycopg ``Identifier``, every
    other keyword is wrapped as ``SQL``, and the formatted result is
    returned as a sqlalchemy ``text`` clause.

    Example:
    ```python
    sql = "select bah from {table} where {limiter} foo = :bar"
    stm = sql_helper(sql, table='foo', limiter='a = :a and ')
    pd.read_sql(stm, conn, params={'bar': 'baz', 'a': 1})
    ```

    Args:
      sql (str): the SQL statement to process
      **kwargs: arguments needed to build the string.

    Returns:
      TextClause: the composed statement.
    """
    # ``table`` is special cased (defaulting to an empty name) and removed
    # from kwargs so that only the SQL fragments remain.
    table = Identifier(kwargs.pop("table", ""))
    fragments = {key: SQL(value) for key, value in kwargs.items()}
    composed = SQL(sql).format(table=table, **fragments)
    return text(composed.as_string())