"""Utility functions for iemwebfarm applications."""
import html
import inspect
import io
import os
import random
import re
import string
import sys
import traceback
import warnings
from collections import namedtuple
from collections.abc import Callable
from datetime import datetime, timezone
from http import HTTPStatus
from typing import Any, Iterator
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from docutils.core import publish_string
from paste.request import parse_formvars
from pydantic import (
AfterValidator,
BaseModel,
BeforeValidator,
ConfigDict,
ValidationError,
WithJsonSchema,
field_validator,
)
from pymemcache.client import Client
from typing_extensions import Annotated
from pyiem.database import get_sqlalchemy_conn, sql_helper
from pyiem.exceptions import (
BadWebRequest,
IncompleteWebRequest,
NewDatabaseConnectionFailure,
NoDataFound,
)
from pyiem.templates.iem import TEMPLATE
from pyiem.util import LOG
# Forgive some typos: timezone abbreviations and alternate spellings that are
# accepted from clients, mapped to the zone name that is used in their place.
TZ_TYPOS = {
    "central": "America/Chicago",
    "CST": "America/Chicago",
    "CDT": "America/Chicago",
    "MST": "America/Denver",
    "MDT": "America/Denver",
    "PST": "America/Los_Angeles",
    "PDT": "America/Los_Angeles",
    "AKST": "America/Anchorage",
    "AKDT": "America/Anchorage",
    "HST": "Pacific/Honolulu",
    "HDT": "Pacific/Honolulu",
    "GMT": "UTC",
    "UT": "UTC",
    "etc/utc": "UTC",
    "utc": "UTC",
}
# Match something that looks like a four digit year at the start of a string
YEAR_RE = re.compile(r"^\d{4}")
# One telemetry record per handled request; the field order here is the
# positional order used when a TELEMETRY instance is constructed.
TELEMETRY = namedtuple(
    "TELEMETRY",
    ["timing", "status_code", "client_addr", "app", "request_uri", "vhost"],
)
# Message placed in the ValueError raised by the XSS validator so that the
# resulting validation error can be told apart from ordinary failures.
XSS_SENTINEL = "XSS"
# environ key that is set to True when the response came from memcached.
MEMCACHED_HIT = "_mhit"
def _conv2list(mixed) -> list:
    """Return ``mixed`` untouched when it is a list, else split it on commas."""
    return mixed if isinstance(mixed, list) else mixed.split(",")
def _ensure_all_strings(mixed) -> list:
    """Return a new list holding only the ``str`` members of ``mixed``."""
    return list(filter(lambda item: isinstance(item, str), mixed))
# Field type for values given either as a list or as one comma separated
# string: the before-validator splits a non-list value on commas and the
# after-validator drops every member that is not a str.  The JSON schema for
# serialization mode is declared as a plain string.
ListOrCSVType = Annotated[
    list,
    BeforeValidator(_conv2list),
    AfterValidator(_ensure_all_strings),
    WithJsonSchema({"type": "string"}, mode="serialization"),
]
def _is_xss_payload(val: str) -> bool:
    """Tell whether ``val`` looks like it carries an XSS payload.

    HTML entities are decoded with `html.unescape` first, so an escaped tag
    is judged the same as a literal one.  The decoded text is flagged when it
    holds something shaped like an opening or closing HTML tag, or a
    ``javascript:`` URI (matched case-insensitively).
    """
    decoded = html.unescape(val)
    # Tag-like content, e.g. <script>, </div>, <img src=...>
    looks_like_tag = re.search(r"<\s*/?\s*[a-zA-Z]", decoded) is not None
    if looks_like_tag:
        return True
    # javascript: URIs are another common payload carrier
    return re.search(r"javascript\s*:", decoded, re.I) is not None
# https://github.com/tiangolo/fastapi/discussions/8143#discussioncomment-5147698
[docs]
class CGIModel(BaseModel):
    """A Pydantic model that parses CGI arguments.

    Validation failures are translated for the web layer: a value rejected
    by :meth:`xss_protect` surfaces as ``BadWebRequest`` and any other
    failure surfaces as ``IncompleteWebRequest``.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **kwargs):
        # Let pydantic validate, then rewrite whatever it complains about.
        try:
            super().__init__(**kwargs)
        except ValidationError as exc:
            details = exc.errors()
            for detail in details:
                # Prefix the location so it reads as a query parameter
                detail["loc"] = ("query",) + detail["loc"]
                cause = detail.get("ctx", {}).get("error", "")
                if str(cause) == XSS_SENTINEL:
                    raise BadWebRequest(
                        f"XSS Key: {detail['loc']} "
                        f"Value: {detail.get('input')}"
                    ) from exc
            raise IncompleteWebRequest(details) from exc

    @field_validator("*", mode="before")
    @classmethod
    def xss_protect(cls, v):
        """Fail validation when a str (or a str inside a list) looks naughty.

        Benign strings that merely contain escaped entities (e.g. '&') are
        let through; the value itself is always returned unchanged.
        """
        if isinstance(v, str):
            suspects = [v]
        elif isinstance(v, list):
            suspects = [item for item in v if isinstance(item, str)]
        else:
            # Another field_validator may have run already, leaving us with
            # a non-str datatype that needs no inspection.
            suspects = []
        if any(_is_xss_payload(item) for item in suspects):
            raise ValueError(XSS_SENTINEL)
        return v
[docs]
def _schema_type_label(prop: dict) -> str:
    """Return a printable type label for one JSON schema property."""
    if "anyOf" in prop:
        return " or ".join(_schema_type_label(x) for x in prop["anyOf"])
    if "type" in prop:
        return str(prop["type"])
    if "$ref" in prop:
        # e.g. "#/$defs/Color" -> "Color"
        return str(prop["$ref"]).rsplit("/", 1)[-1]
    if "allOf" in prop:
        return " and ".join(_schema_type_label(x) for x in prop["allOf"])
    # No type information at all in the schema entry
    return "any"


def model_to_rst(model: "type[BaseModel]") -> str:
    """Convert a Pydantic model to a reStructuredText table.

    The table is built from ``model.model_json_schema()``: one row per
    property with its type label (``array`` is shown as
    "Multi-Params or CSV value"), a "(required)" marker when the property is
    listed as required, and its description.  Schema entries that carry no
    ``type`` key (for example a ``$ref``) get a best-effort label instead of
    raising ``KeyError``.

    Args:
      model: The Pydantic model to convert

    Returns: A reStructuredText document fragment containing a list-table
    """
    rst = [
        "CGI Arguments",
        "-------------",
        "",
        """
The following table lists the CGI arguments that are accepted by this service.
A HTTP ``GET`` request is required. Fields of type
**Multi-Params or CSV value** can accept either a comma separated list or
multiple parameter and value combinations. For example, ``?foo=1&foo=2`` is
equivalent to ``?foo=1,2``.
""",
        "",
        ".. list-table::",
        "   :header-rows: 1",
        "   :widths: 15 15 70",
        "",
        # list-table wants exactly one two-level bullet list, so each cell
        # ("- ...") has to line up under the first cell of its row ("* - ").
        "   * - Field",
        "     - Type",
        "     - Description",
    ]
    schema = model.model_json_schema()
    required_keys = set(schema.get("required", []))
    for key, prop in schema.get("properties", {}).items():
        required = " (required)" if key in required_keys else ""
        typetext = _schema_type_label(prop)
        if typetext == "array":
            typetext = "Multi-Params or CSV value"
        desc = str(prop.get("description", "")).strip()
        if desc:
            # Continuation lines must align with the cell text to stay
            # inside the same list item.
            desc = "\n       ".join(
                [line.strip() for line in desc.splitlines() if line.strip()]
            )
        rst.append(
            f"   * - {key}\n     - {typetext}{required}\n     - {desc}"
        )
    return "\n".join(rst)
[docs]
def write_telemetry(data: TELEMETRY) -> bool:
    """Insert one telemetry record into the ``mesosite`` database.

    Args:
      data: the record; its fields are bound by name to the insert.

    Returns: True once the insert was committed, False on any failure.
    """
    statement = (
        "\n"
        "insert into website_telemetry(timing, status_code,\n"
        "client_addr, app, request_uri, vhost)\n"
        "values (:timing, :status_code, :client_addr,\n"
        ":app, :request_uri, :vhost)\n"
    )
    # Yes, this blocks, but if this database is not working, we are in trouble
    try:
        with get_sqlalchemy_conn("mesosite", rw=True) as conn:
            conn.execute(sql_helper(statement), data._asdict())
            conn.commit()
    except NewDatabaseConnectionFailure:
        # A missing database is quietly tolerated
        return False
    except Exception as exp:
        LOG.exception(exp)
        return False
    return True
[docs]
def ensure_list(environ, key, parse_commas=True) -> list:
    """Fetch ``key`` from ``environ`` as a list, empty when it is absent.

    Args:
      environ: the WSGI environ
      key: the key to look for
      parse_commas: split each found value based on commas (stripping
        whitespace around every piece), default True.
    """
    if key not in environ:
        return []
    found = environ[key]
    values = found if isinstance(found, list) else [found]
    if not parse_commas:
        return values
    return [piece.strip() for entry in values for piece in entry.split(",")]
[docs]
def log_request(environ: dict, multiplier: int = 1):
    """Record the request in the ``weblog`` table for later processing.

    Args:
      environ: the WSGI environ the logged values are read from.
      multiplier: how many identical rows to insert, default 1.
    """
    statement = (
        "\n"
        "INSERT into weblog\n"
        "(client_addr, uri, referer, http_status, x_forwarded_for, domain)\n"
        "VALUES (:client_addr, :uri, :referer, :http_status, "
        ":x_forwarded_for,\n"
        ":domain)\n"
    )
    params = {
        "client_addr": environ.get("REMOTE_ADDR"),
        "uri": environ.get("REQUEST_URI"),
        "referer": environ.get("HTTP_REFERER"),
        # Every row is written with a fixed 404 status
        "http_status": 404,
        "x_forwarded_for": environ.get("HTTP_X_FORWARDED_FOR"),
        "domain": environ.get("HTTP_HOST"),
    }
    with get_sqlalchemy_conn("mesosite") as conn:
        for _ in range(multiplier):
            conn.execute(sql_helper(statement), params)
        conn.commit()
[docs]
def compute_ts_from_string(form, key):
    """Parse ``form[key]`` into a timezone aware datetime.

    Accepted shapes are ``YYYY-mm-dd HH:MM``, ``YYYY-mm-dd HH:MM:SS`` and
    ``YYYY-mm-dd HH:MM:SS.ffffff``, with either a space or a ``T`` between
    date and time.  A trailing ``Z`` selects UTC, otherwise ``form["tz"]``
    (default ``America/Chicago``) is attached.
    """
    text = form[key].replace("T", " ")
    zone = ZoneInfo(form.get("tz", "America/Chicago"))
    if text.endswith("Z"):
        zone = ZoneInfo("UTC")
        text = text[:-1]
    if text.count(":") == 1:
        # Only hours and minutes were provided
        layout = "%Y-%m-%d %H:%M"
    elif "." in text:
        layout = "%Y-%m-%d %H:%M:%S.%f"
    else:
        layout = "%Y-%m-%d %H:%M:%S"
    return datetime.strptime(text, layout).replace(tzinfo=zone)
[docs]
def compute_ts(form, suffix):
    """Build a timezone aware datetime from split-out form fields.

    The fields ``year{suffix}``, ``month{suffix}`` and ``day{suffix}`` are
    read, falling back to the un-suffixed ``year``/``month``/``day``.
    ``hour{suffix}`` and ``minute{suffix}`` default to 0 and the timezone
    comes from ``form["tz"]`` (default ``America/Chicago``).

    Impossible days of the month are forgiven by clamping to the last valid
    day: the 31st of a 30 day month becomes the 30th and late-February days
    become the 28th, or the 29th in a (Gregorian) leap year.

    Args:
      form: mapping with the fields described above.
      suffix: string appended to the field names, e.g. ``"1"``.

    Raises:
      TypeError, ValueError: when a field is missing or not a valid number.
    """
    # NB: form["tz"] should always be set by this point, but alas
    month = int(form.get(f"month{suffix}", form.get("month")))
    day = min(int(form.get(f"day{suffix}", form.get("day"))), 31)
    # Forgive specification of two years, with pydantic, this could be None
    yearval = form.get(f"year{suffix}") or form.get("year")
    if isinstance(yearval, list) and len(set(yearval)) == 1:
        yearval = yearval[0]
    # Convert once; form values are typically strings, so comparing the raw
    # value against integers would never match.
    year = int(yearval)
    if month in (4, 6, 9, 11) and day == 31:
        # Forgive bad day of the month combinations
        day = 30
    elif month == 2 and day > 28:
        # Forgive February 29ths (and later) on non-leap years
        is_leap = year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)
        day = 29 if is_leap else 28
    return datetime(
        year,
        month,
        day,
        int(form.get(f"hour{suffix}", 0)),
        int(form.get(f"minute{suffix}", 0)),
        tzinfo=ZoneInfo(form.get("tz", "America/Chicago")),
    )
[docs]
def add_to_environ(environ: dict, form: dict, **kwargs):
    """Build out some things auto-parsed from the request.

    Form values are copied into ``environ`` (never over-writing a key that
    is already there) and, unless disabled, ``sts`` / ``ets`` timestamps
    are computed and stored as well.

    Args:
      environ: the WSGI environ, updated in place.
      form: the parsed CGI form; without a schema, its ``tz`` entry is set
        or normalized in place.
      **kwargs: options, the following are read:

        - allowed_as_list: when non-empty, any list-valued form key that is
          not named here is rejected.
        - schema: a model class; the form is filtered to its fields,
          validated, and replaced by the model dump.  The instance is kept
          at ``environ["_cgimodel_schema"]``.
        - default_tz: timezone used when the form has none, default
          ``America/Chicago``.
        - parse_times: compute ``sts``/``ets``, default True.

    Raises:
      BadWebRequest: disallowed list value or a value that looks like XSS.
      IncompleteWebRequest: invalid timestamp or timezone.
    """
    # Typically apps without schema definition, so this removes some
    # downstream boilerplate to prevent unwanted lists
    allowed = kwargs.get("allowed_as_list", [])
    if allowed:
        for key in form:
            if not isinstance(form[key], list):
                continue
            if key not in allowed:
                raise BadWebRequest(f"Key {key} is a list, but not allowed")
    if "schema" in kwargs:
        # Remove any form keys that are not contained in the
        # schema, but first we debracket keys with []
        form = {
            k: v
            for k, v in _debracket(form).items()
            if k in kwargs["schema"].model_fields and v != ""
        }
        # Retain a reference to the Schema instance as it may have
        # private / computed attributes that are needed
        environ["_cgimodel_schema"] = kwargs["schema"](**form)
        form = environ["_cgimodel_schema"].model_dump()
    if "tz" not in form:
        form["tz"] = kwargs.get("default_tz", "America/Chicago")
    if isinstance(form["tz"], list):
        # A repeated ``tz=`` parameter shows up as a list, which is not
        # hashable and so cannot be looked up in TZ_TYPOS.  Report it as a
        # client error instead of failing with a TypeError.
        raise IncompleteWebRequest("Invalid timezone specified")
    # Important: tz is normalized before the timestamps are computed below
    form["tz"] = TZ_TYPOS.get(form["tz"], form["tz"])
    for key, val in form.items():
        if key not in environ:
            # check for XSS and other naughty things
            # We should only have either lists or strings
            if isinstance(val, list):
                for va in val:
                    if isinstance(va, str) and _is_xss_payload(va):
                        raise BadWebRequest(f"XSS Key: {key} Value: {va}")
            elif isinstance(val, str):
                if _is_xss_payload(val):
                    raise BadWebRequest(f"XSS Key: {key} Value: {val}")
            environ[key] = form[key]
        else:
            warnings.warn(
                f"Refusing to over-write environ key {key}",
                UserWarning,
                stacklevel=1,
            )
    if kwargs.get("parse_times", True):
        try:
            sts = None
            ets = None
            # Le Sigh, daryl used sts for stations in the past, so ensure
            # that sts starts with something that looks like a year
            if isinstance(form.get("sts"), str) and YEAR_RE.match(form["sts"]):
                sts = compute_ts_from_string(form, "sts")
            if isinstance(form.get("ets"), str) and YEAR_RE.match(form["ets"]):
                ets = compute_ts_from_string(form, "ets")
            # NB: The usage of a schema may have already parsed a sts or ets,
            # but it will be None if it was not provided
            if form.get("day1") is not None and form.get("sts") is None:
                sts = compute_ts(form, "1")
            if form.get("day2") is not None and form.get("ets") is None:
                ets = compute_ts(form, "2")
            if sts:
                environ["sts"] = sts
                # Keep the retained schema instance in sync
                if (
                    "_cgimodel_schema" in environ
                    and "sts" in environ["_cgimodel_schema"].model_fields
                ):
                    environ["_cgimodel_schema"].sts = sts
            if ets:
                environ["ets"] = ets
                if (
                    "_cgimodel_schema" in environ
                    and "ets" in environ["_cgimodel_schema"].model_fields
                ):
                    environ["_cgimodel_schema"].ets = ets
        except (TypeError, ValueError) as exp:
            raise IncompleteWebRequest("Invalid timestamp specified") from exp
        except (IsADirectoryError, ZoneInfoNotFoundError) as exp:
            raise IncompleteWebRequest("Invalid timezone specified") from exp
def _handle_help(start_response, **kwargs):
    """Render the help page for a service.

    Args:
      start_response: the WSGI start_response function
      kwargs: the keyword arguments passed to the decorator

    Returns The HTML response, encoded as UTF-8 bytes
    """
    start_response("200 OK", [("Content-type", "text/html")])
    # Help text as given to the decorator, plus the schema table if any
    source = kwargs.get("help", "Help not available")
    if "schema" in kwargs:
        source = f"{source}\n\n{model_to_rst(kwargs['schema'])}"
    captured = io.StringIO()
    overrides = {
        # Keep docutils CSS external so we can style ourselves.
        "embed_stylesheet": False,
        # Collect docutils warnings in a buffer so they can be logged.
        "warning_stream": captured,
    }
    html_doc = publish_string(
        source=source,
        writer_name="html",
        settings_overrides=overrides,
    ).decode("utf-8")
    problems = captured.getvalue().strip()
    if problems:
        LOG.warning("Docutils warnings while rendering help: %s", problems)
    # The stylesheet for docutils-generated content ships next to this file
    stylesheet = os.path.join(
        os.path.dirname(__file__), "data", "docutils-help.css"
    )
    with open(stylesheet, "r", encoding="utf-8") as fh:
        css = fh.read()
    # Keep only what sits between <body> and </body>
    after_open = html_doc.split("<body>")[1]
    body = after_open.split("</body>")[0]
    page = (
        f"<style>\n{css}\n</style>"
        f'<div class="container-fluid">{body}</div>'
    )
    return TEMPLATE.render({"content": page}).encode("utf-8")
def _debracket(form):
    """Return a copy of ``form`` whose keys have a trailing ``[]`` removed."""
    return {key.removesuffix("[]"): value for key, value in form.items()}
def _mcall(
func: Callable,
environ: dict,
start_response: Callable,
memcachekey: str | Callable | None,
expire: int | Callable,
content_type: str | Callable,
):
"""Call the function with memcachekey handling."""
if memcachekey is None:
return func(environ, start_response)
key = memcachekey if isinstance(memcachekey, str) else memcachekey(environ)
# Ensure that the key is not too long for memcache
if key is None or len(key) > 249 or inspect.isgeneratorfunction(func):
# An appside short circuit when we programatically do not want cache
# or we are dealing with a generator
return func(environ, start_response)
mc = Client("iem-memcached:11211")
res = mc.get(key)
if not res:
res = func(environ, start_response)
# IEM memcache instances run with a 10MB limit `-I 10m`, so check first
if len(res) < 10e6:
mc.set(
key,
res,
expire if isinstance(expire, int) else expire(environ),
)
else:
# since our function never got called, we need to start_response
ct = (
content_type
if isinstance(content_type, str)
else content_type(environ)
)
start_response("200 OK", [("Content-type", ct)])
environ[MEMCACHED_HIT] = True
cb = environ.get("callback")
if cb is not None:
if isinstance(res, str):
res = f"{cb}({res})"
elif isinstance(res, bytes):
res = f"{cb}({res.decode('utf-8')})"
mc.close()
return res
[docs]
def ip_is_throttled(environ: dict, throttle_secs: float | Callable) -> bool:
"""Return True if the REMOTE_ADDR is throttled."""
client_ip = environ.get("REMOTE_ADDR")
if not client_ip or client_ip.startswith(("127.", "129.186.", "10.")):
return False
if isinstance(throttle_secs, Callable):
throttle_secs = throttle_secs(environ)
if throttle_secs > 0:
try:
mc = Client("iem-memcached:11211")
key = f"throttle:{client_ip}"
res = mc.get(key)
if res:
return True
mc.set(key, "1", expire=int(throttle_secs) + 1)
except Exception:
pass
return False
def _iemapp_error_response(
    environ: dict,
    start_response: Callable,
    errormsg: str,
    routine: bool = False,
    code: int = 500,
) -> bytes:
    """Start a text/plain error response and return its body.

    For a ``routine`` error the body is ``errormsg`` itself.  Otherwise the
    body is a generic apology carrying a random identifier, and the request
    URI plus ``errormsg`` are written to stderr under that identifier.
    """
    # A random tag that ties the client's message to the stderr entry
    alphabet = string.ascii_uppercase + string.digits
    uid = "".join(random.choice(alphabet) for _ in range(12))
    if routine:
        msg = errormsg
    else:
        msg = (
            "Oopsy, something failed on our end, but fear not.\n"
            "Please contact akrherz@iastate.edu and reference "
            f"this unique identifier: {uid}\n"
            "Or wait a day for daryl to review the web logs and fix "
            "the bugs he wrote. What a life."
        )
        # Nicely log things about this actual request
        sys.stderr.write(f"={uid} URL: {environ.get('REQUEST_URI')}\n")
        sys.stderr.write(errormsg)
    status_line = f"{code} {HTTPStatus(code).phrase}"
    start_response(status_line, [("Content-type", "text/plain")])
    return msg.encode("ascii", errors="replace")
def _iemapp_preflight(
    environ: dict,
    start_response: Callable,
    kwargs: dict[str, Any],
    ip_throttle_secs: float | Callable,
) -> tuple[bool, bytes | None]:
    """Parse the request and decide if it can be answered right away.

    Returns: ``(True, payload)`` when the response has already been started
    here (help page or 429 throttle notice), else ``(False, None)`` once
    ``environ`` has been populated from the form.
    """
    # mixed converts this to a regular dict
    form = clean_form(parse_formvars(environ).mixed())
    if "help" in form:
        return True, _handle_help(start_response, **kwargs)
    add_to_environ(environ, form, **kwargs)
    if not ip_is_throttled(environ, ip_throttle_secs):
        return False, None
    start_response(
        "429 Too Many Requests",
        [("Content-type", "text/plain")],
    )
    return True, b"Too many requests from your IP address, slow down."
def _normalize_iemapp_response(res: Any) -> Iterator[bytes]:
    """Yield the chunks of ``res``, whatever shape the app returned.

    A str is yielded once, UTF-8 encoded; bytes are yielded once as-is; any
    other value is iterated and its items are yielded unchanged.
    """
    # Need to be careful here and ensure a bare str/bytes is one chunk
    # rather than being iterated character by character.
    if isinstance(res, str):
        res = (res.encode("utf-8"),)
    elif isinstance(res, bytes):
        res = (res,)
    yield from res
def _iemapp_emit_telemetry(
    environ: dict,
    start_time: datetime,
    status_code: int,
) -> None:
    """Write one telemetry record for the request that began at start_time.

    The timing is the number of seconds between ``start_time`` and now
    (UTC); the remaining fields are read from ``environ``.
    """
    elapsed = datetime.now(timezone.utc) - start_time
    record = TELEMETRY(
        timing=elapsed.total_seconds(),
        status_code=status_code,
        client_addr=environ.get("REMOTE_ADDR"),
        app=environ.get("SCRIPT_NAME"),
        request_uri=environ.get("REQUEST_URI"),
        vhost=environ.get("HTTP_HOST"),
    )
    write_telemetry(record)
def _parse_status_code(status: str) -> int | None:
"""Parse integer HTTP status code from a WSGI status line."""
# This could raise, but this should not be accounted for.
return int(status.split()[0])
def _capture_start_response(start_response: Callable) -> tuple[Callable, dict]:
    """Wrap ``start_response`` so its use can be observed.

    Returns: the wrapper and a state dict.  The dict starts out as
    ``{"started": False, "status_code": None}`` and is filled in when the
    wrapper is called; the wrapper forwards to ``start_response`` and
    returns whatever it returns.
    """
    state = {
        "started": False,
        "status_code": None,
    }

    def _wrapped_start_response(status, headers, exc_info=None):
        state["started"] = True
        state["status_code"] = _parse_status_code(status)
        # Only hand exc_info on when the caller actually supplied it
        extra = () if exc_info is None else (exc_info,)
        return start_response(status, headers, *extra)

    return _wrapped_start_response, state
def _iemapp_handle_exception(
    environ: dict,
    start_response: Callable,
    exp: Exception,
) -> tuple[int, bytes]:
    """Map exceptions to status code and user-facing payload.

    - ``IncompleteWebRequest`` / ``NoDataFound``: 422, the exception text is
      returned to the client as-is.
    - ``BadWebRequest``: 422, and the request is additionally written to
      the web log (best effort).
    - ``NewDatabaseConnectionFailure``: 503.
    - anything else: 500, with the current traceback as the logged detail.

    Returns: the status code and the response body; the response itself is
    started through ``start_response`` while building the body.
    """
    if isinstance(exp, (IncompleteWebRequest, NoDataFound)):
        status_code = 422
        return status_code, _iemapp_error_response(
            environ,
            start_response,
            str(exp),
            routine=True,
            code=status_code,
        )
    if isinstance(exp, BadWebRequest):
        status_code = 422
        # Logging the offender is auxiliary; a database problem here must
        # not keep the client from getting its 422.
        try:
            log_request(environ, multiplier=2)
        except NewDatabaseConnectionFailure:
            # swallow this, same as is done for telemetry
            pass
        except Exception as log_exp:
            LOG.exception(log_exp)
        return status_code, _iemapp_error_response(
            environ,
            start_response,
            str(exp),
            code=status_code,
        )
    if isinstance(exp, NewDatabaseConnectionFailure):
        status_code = 503
        return status_code, _iemapp_error_response(
            environ,
            start_response,
            f"get_dbconn() failed with `{exp}`",
            code=status_code,
        )
    return 500, _iemapp_error_response(
        environ,
        start_response,
        traceback.format_exc(),
    )
[docs]
def iemapp(**kwargs):
    """Attempt to do all kinds of nice things for the user and the developer.

    This is a decorator factory for WSGI applications of the form
    ``func(environ, start_response)``.  The returned wrapper is a generator
    function that keeps the name and docstring of ``func``.

    kwargs:
      - default_tz: The default timezone to use for timestamps, the default
        is ``America/Chicago``.
      - enable_telemetry: Enable telemetry logging, default ``True``.
      - help: Default help text, default ``Help not available``.
      - parse_times: Parse the form for timestamps, default ``True``.
      - schema (BaseModel): A Pydantic model to parse the form with.
      - memcachekey (str or callable): A memcache key to use for caching
        the response.  If the callable returns `None`, no caching is done.
      - memcacheexpire (int or callable): The number of seconds to cache
        the response, defaults to 3600.
      - content_type (str or callable): The content type to use for the
        response.
      - allowed_as_list (list): CGI parameters that are permitted to be
        lists.
      - ip_throttle_secs (float or callable): Number of seconds between
        requests from the same REMOTE_ADDR, 0 to disable,
        which is the default.

    What all this does:
      1) Attempts to catch database connection errors and handle nicely
      2) Updates `environ` with some auto-parsed values + form content.
      3) If the wrapped function returns a str or bytes, it will be encoded
         and made into a list for the WSGI response.

    Exception Raising
    -----------------

    The following Exception types raised within the mod_wsgi wrapped code will
    trigger the following HTTP status codes sent to the client.

      - `NoDataFound` or `IncompleteWebRequest` -> 422 Unprocessable Entity
      - `BadWebRequest` -> 422 Unprocessable Entity (also db logged...)
      - `NewDatabaseConnectionFailure` -> 503 Service Unavailable
      - Any other Exception -> 500 Internal Server Error
    """
    # Function-scope import: functools is only needed by this decorator.
    from functools import wraps

    enable_telemetry = kwargs.get("enable_telemetry", True)
    ip_throttle_secs = kwargs.get("ip_throttle_secs", 0)
    memcachekey = kwargs.get("memcachekey")
    memcacheexpire = kwargs.get("memcacheexpire", 3600)
    content_type = kwargs.get("content_type", "application/json")

    def _decorator(func):
        """Decorate a function to catch exceptions and do nice things."""

        # wraps() carries func's __name__/__doc__ over to the wrapper
        @wraps(func)
        def _wrapped(environ, start_response):
            """Decorate function."""
            start_time = datetime.now(timezone.utc)
            status_code = 500
            wrapped_start_response, response_state = _capture_start_response(
                start_response
            )
            try:
                short_circuit, payload = _iemapp_preflight(
                    environ,
                    wrapped_start_response,
                    kwargs,
                    ip_throttle_secs,
                )
                if short_circuit:
                    # help page or throttle notice, nothing more to do
                    yield payload
                    return
                res = _mcall(
                    func,
                    environ,
                    wrapped_start_response,
                    memcachekey,
                    memcacheexpire,
                    content_type,
                )
                # you know what assumptions do
                status_code = 200
                # Keep generator iteration in the try block so downstream
                # iteration exceptions are mapped by our exception handler.
                if inspect.isgenerator(res):
                    yield from _normalize_iemapp_response(res)
                    if enable_telemetry and not environ.get(
                        MEMCACHED_HIT, False
                    ):
                        _iemapp_emit_telemetry(
                            environ, start_time, status_code
                        )
                    return
            except Exception as exp:
                if response_state["started"]:
                    # Once streaming has started, we cannot safely restart
                    # the response with a new status/body.
                    LOG.exception(
                        "iemapp: exception raised after start_response."
                    )
                    res = []
                    status_code = response_state["status_code"] or status_code
                else:
                    status_code, res = _iemapp_handle_exception(
                        environ,
                        wrapped_start_response,
                        exp,
                    )
            # Responses served from memcached are not recorded
            if enable_telemetry and not environ.get(MEMCACHED_HIT, False):
                _iemapp_emit_telemetry(
                    environ,
                    start_time,
                    response_state["status_code"] or status_code,
                )
            yield from _normalize_iemapp_response(res)

        return _wrapped

    return _decorator