# -*- coding: utf-8 -*-
#
# Session management.
#
# ------------------------------------------------
# imports
# -------
import atexit
import base64
from contextvars import ContextVar
from collections import Counter
from dataclasses import dataclass, field
from functools import wraps
import json
import logging
import os
import re
import sys
from typing import Mapping
import traceback
import warnings
import time
try:
import cookielib as cookiejar
except ImportError:
import http.cookiejar as cookiejar
from gems import cached
import requests
from requests import ConnectionError
from requests.packages import urllib3
import six.moves.urllib as urllib
from .utils import debugtime, is_uuid
from espclient.config import DEFAULT_CLIENT_COOKIES
# Per-API request timing is on unless L7_CLIENT_ENABLE_TIMING is set to
# something other than "true" (compared case-insensitively). When on, the
# ``expect`` decorator records timings into ``_api_calls`` and a report is
# written to stderr at interpreter exit.
_ENABLE_TIMING = str(os.environ.get("L7_CLIENT_ENABLE_TIMING", "true")).lower() == "true"
# config
# ------
# Module-level configuration and shared Session; both are (re)assigned by
# ``options()``.
CONFIG = None
SESSION = None
_debug = debugtime("API")
# Requests in this module pass verify=False, so silence urllib3's
# InsecureRequestWarning that would otherwise be emitted for every HTTPS call.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
logger = logging.getLogger(__name__)
# Import-time timestamp (ns); the exit report uses it to approximate the
# total run time of the program.
_module_load_time = time.perf_counter_ns()
# Divisors converting a nanosecond duration into the named unit.
_NS_TO_UNIT_DIVISOR = {
    "ns": 1,
    "us": 1e3,
    "ms": 1e6,
    "s": 1e9,
}
@dataclass
class ApiCall:
    """Running timing statistics for one "base" API (see ``base_api``).

    All durations are integers in nanoseconds, as measured with
    ``time.perf_counter_ns`` by the ``expect`` decorator.
    """

    # Class attribute (no annotation, so NOT a dataclass field). Matches one
    # whole path segment that is a canonical 8-4-4-4-12 UUID: preceded by "/"
    # and followed by "/" or the end of the path.
    _UUID_SEGMENT_RE = re.compile(
        r"/[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}(?=/|$)",
        re.IGNORECASE,
    )

    api: str
    method: str
    call_count: int = 0
    # Sentinel larger than any real duration: 10^18 ns is ~32 years, so the
    # first recorded call always replaces it.
    min_time: int = 10**18
    max_time: int = 0
    cumulative_time: int = 0
    # Full endpoint (UUIDs intact) -> number of calls made to it.
    unique_call_counts: Counter = field(default_factory=Counter)

    @classmethod
    def base_api(cls, method, endpoint) -> str:
        """Return ``"<method> <path>"`` with every UUID path segment replaced by ``:uuid``.

        The query string and fragment of ``endpoint`` are dropped and the
        method is kept as given. For instance::

            GET /api/samples/608f4f2e-843c-492c-bec6-a51f6f03e0ba -> GET /api/samples/:uuid
            GET /api/samples                                      -> GET /api/samples
            PUT /api/samples/<uuid>?x=1                           -> PUT /api/samples/:uuid
            GET /api/samples/<uuid>/children/<uuid>               -> GET /api/samples/:uuid/children/:uuid

        Keeping the with/without-UUID distinction makes it easy to tell bulk
        fetch/update calls from individual ones, while still aggregating all
        individual calls to the same resource type into a single entry.
        """
        path = urllib.parse.urlparse(endpoint).path
        return "{} {}".format(method, cls._UUID_SEGMENT_RE.sub("/:uuid", path))

    def update_timing_data(self, endpoint, total_time):
        """Record one call to ``endpoint`` that took ``total_time`` nanoseconds."""
        self.call_count += 1
        self.unique_call_counts[endpoint] += 1
        self.cumulative_time += total_time
        self.min_time = min(self.min_time, total_time)
        self.max_time = max(self.max_time, total_time)

    @property
    def mean_time(self):
        """Average duration per call in nanoseconds (0.0 if there were no calls)."""
        if self.call_count > 0:
            return self.cumulative_time / self.call_count
        return 0.0

    def report_string(self, units="ns"):
        """Return ``report_list(units)`` as a comma-separated string."""
        return ",".join(str(value) for value in self.report_list(units))

    def report_list(self, units="ns"):
        """Return ``[api, #calls, cumulative, min, max, mean, #unique endpoints]``.

        Durations are converted to ``units`` ("ns", "us", "ms" or "s"); an
        unknown unit falls back to nanoseconds. With no recorded calls, min
        is reported as 0 (like mean) rather than the internal sentinel.
        """
        divisor = _NS_TO_UNIT_DIVISOR.get(units, 1)
        min_time = self.min_time if self.call_count else 0
        return [
            self.api,
            self.call_count,
            self.cumulative_time / divisor,
            min_time / divisor,
            self.max_time / divisor,
            self.mean_time / divisor,
            len(self.unique_call_counts),
        ]
# Aggregated timing data keyed by the ``ApiCall.base_api`` string
# (e.g. "GET /api/samples/:uuid"); populated by the ``expect`` decorator.
# No type hint for backwards compatibility: once Python 3.8 support is
# dropped this can become ``_api_calls: dict[str, ApiCall] = {}``.
_api_calls = {}


def _report_apicalls(*args, **kwargs):
    """Write the collected per-API timing table to stderr as a JSON document.

    Registered with ``atexit`` when timing is enabled. Any arguments are
    accepted and ignored. Nothing is written if no calls were recorded.
    Calls are listed by descending cumulative time (ties broken by mean time),
    with all durations in milliseconds.
    """
    program_exit_time = time.perf_counter_ns()
    calls = sorted(
        _api_calls.values(),
        key=lambda call: (call.cumulative_time, call.mean_time),
        reverse=True,
    )
    if not calls:
        return
    units = "ms"
    divisor = _NS_TO_UNIT_DIVISOR.get(units, 1)
    headers = ["API", "#Calls", "Cumulative", "Min", "Max", "Mean", "#UniqueEndpoints"]
    elapsed = (program_exit_time - _module_load_time) / divisor
    # Assembled by hand (rather than a single json.dumps(indent=...)) so each
    # call stays on one line; every value still goes through json.dumps and
    # the separators/brackets are emitted explicitly, so the result is valid JSON.
    call_rows = ",\n".join("    {}".format(json.dumps(call.report_list(units))) for call in calls)
    lines = [
        "{",
        '  "units": {},'.format(json.dumps(units)),
        '  "approximate_execution_time": {},'.format(json.dumps(elapsed)),
        '  "call_headers": {},'.format(json.dumps(headers)),
        '  "calls": [',
        call_rows,
        "  ]",
        "}",
    ]
    sys.stderr.write("\n".join(lines) + "\n")
    sys.stderr.flush()


if _ENABLE_TIMING:
    atexit.register(_report_apicalls)
# Status codes that ``expect`` must NOT retry in the current context.
# Managed by ``noretry_on``; readers use an empty-set default.
_noretry_codes: ContextVar[set] = ContextVar("_noretry_codes")


class noretry_on:
    """Context manager that disables ``expect``'s automatic retries for given status codes.

    Nested contexts accumulate: inside ``with noretry_on(404): with noretry_on(504):``
    neither 404 nor 504 is retried. The previous set is restored on exit.

    Example::

        with noretry_on(404):
            session.get("/api/samples/...")  # a 404 fails immediately
    """

    def __init__(self, *codes):
        # The codes this particular context adds.
        self.skip_retry_codes = set(codes)
        self.token = None

    def __enter__(self):
        # Merge with any enclosing context so an inner scope cannot re-enable
        # retries that an outer scope switched off.
        active = _noretry_codes.get(set()) | self.skip_retry_codes
        self.token = _noretry_codes.set(active)
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Restore the enclosing set; exceptions are not suppressed.
        _noretry_codes.reset(self.token)
        self.token = None
# decorators
def _unexpected_response_error(res, codes):
    """Build (not raise) the AssertionError reported when ``res.status_code`` is not in ``codes``.

    The body is decoded as UTF-8 JSON; when that yields an object, its
    ``error`` and ``traceback`` entries are included in the message.
    """
    try:
        content = json.loads(res.content.decode("utf8"))
    except (AttributeError, ValueError):
        # ValueError covers UnicodeDecodeError and json.JSONDecodeError;
        # AttributeError covers a missing (None) body.
        return AssertionError(
            "\n\n".join(["Response from server malformed (could not be decoded):", str(res.content)])
        )
    if not isinstance(content, dict):
        # Valid JSON that is not an object (e.g. a bare string): report it as the error.
        content = {"error": content}
    return AssertionError(
        "\n\n".join(
            [
                "Response from server ({}) different from "
                "expected {}".format(res.status_code, codes),
                "SERVER ERROR: {}".format(content.get("error")),
                "SERVER TRACEBACK: {}".format(content.get("traceback")),
            ]
        )
    )


def expect(*codes, retry_on=None, attempts=2):
    """Decorator for ``Session`` request methods: check the status code and retry transient failures.

    Args:
        *codes: Status codes treated as success.
        retry_on: Status codes that trigger another attempt, minus any codes
            disabled for the current context via ``noretry_on``.
        attempts (int): Maximum number of requests to issue (>= 1). Attempts
            are separated by sleeps of 2, 4, 8, ... seconds; no sleep happens
            after the final attempt.

    ``ConnectionError`` from the wrapped call is retried the same way and
    re-raised if it happens on the final attempt. When timing is enabled,
    the total wall time (including retries and sleeps) is recorded in
    ``_api_calls`` whether the call succeeds or raises.

    Raises:
        ValueError: at decoration time, if ``attempts`` < 1.
        AssertionError: if the final response's status is not in ``codes``.
    """
    # Everything after *codes is keyword-only (PEP 3102).
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    if retry_on is None:
        retry_on = []

    def decorator(func):
        @wraps(func)
        def _(self, endpoint, **kwargs):
            start_time = time.perf_counter_ns()
            local_retry = set(retry_on) - _noretry_codes.get(set())
            try:
                for attempt in range(attempts):
                    is_last = attempt == attempts - 1
                    try:
                        res = func(self, endpoint, **kwargs)
                    except ConnectionError:
                        if is_last:
                            raise
                        # simple power-of-2 exponential backoff
                        time.sleep(2 ** (attempt + 1))
                        continue
                    if res.status_code in local_retry and not is_last:
                        time.sleep(2 ** (attempt + 1))
                        continue
                    # Final or non-retryable response: succeed or fail loudly,
                    # never fall through and return None.
                    if res.status_code not in codes:
                        raise _unexpected_response_error(res, codes)
                    return res
            finally:
                if _ENABLE_TIMING:
                    total_time = time.perf_counter_ns() - start_time
                    method = func.__name__.upper()
                    api = ApiCall.base_api(method=method, endpoint=endpoint)
                    _api_calls.setdefault(api, ApiCall(api=api, method=method)).update_timing_data(
                        endpoint, total_time
                    )

        return _

    return decorator
def logged_in(func):
    """Decorator for ``Session`` request methods: call ``self.login()`` first if not yet logged in.

    Endpoints whose path contains "login" skip the implicit login.

    Raises:
        ConnectionError: if ``self.login()`` fails for any reason; the
            original exception is chained as ``__cause__`` so the real
            failure (e.g. the server's authentication error) is not lost.
    """

    @wraps(func)
    def _(self, endpoint, **kwargs):
        if not self._logged_in and "login" not in endpoint:
            try:
                self.login()
            except Exception as err:
                # Exception (not a bare except) so KeyboardInterrupt/SystemExit
                # propagate unchanged.
                raise ConnectionError(
                    "Could not connect to ESP. Either the server is down or login credentials are invalid"
                ) from err
        return func(self, endpoint, **kwargs)

    return _
# objects
# -------
class Session(object):
    """
    Object for managing connections to esp.

    Args:
        host (str): Host with database to connect to.
        port (int): Port to connect to database with. Left out of the URL
            when it is None, 80 or 443.
        cookies (str): Optional path to a Mozilla-format cookie file whose
            cookies are sent with every request.
        headers (dict): Optional extra HTTP headers sent with every request.
        username (str): Username to access application with. Defaults to
            ``"admin@localhost"`` when neither ``username`` nor ``email`` is given.
        password (str): Password to be used with username.
        cache (bool): Whether or not to cache objects for faster querying.
        ssl (bool): Use https. Port 443 also implies https.
        token (str): Optional API token to use in place of username and password
            for authentication. Can be a string or path to a file containing the
            token.
        email (str): Username to access application with. Deprecated but
            preserved through at least 2.3 for backwards compatibility.
            If both username and email are provided, username wins.
        tenant (str): Optional tenant sent in the password-login payload.

    Attributes:
        session (requests.Session): Internal ``requests.Session`` object for managing
            connection.
    """

    def __init__(
        self,
        host="127.0.0.1",
        port=8002,
        cookies=None,
        headers=None,
        username=None,
        password="password",
        cache=False,
        ssl=False,
        token=None,
        email=None,
        tenant=None,
    ):
        # ``username`` defaults to None (resolved to "admin@localhost" below)
        # so that a caller passing only the deprecated ``email`` is honoured.
        if email is not None:
            warnings.warn("email property is deprecated. Use username instead.", DeprecationWarning, stacklevel=2)
            if username is None:
                username = email
        if username is None:
            username = "admin@localhost"

        # save basics
        self.host = host
        self.port = port
        self.ssl = ssl
        self.tenant = tenant

        # set up url; strip a trailing slash from the host *before* appending
        # the port, otherwise "host/" would give "http://host/:8002".
        protocol = "https" if self.ssl or port == 443 else "http"
        self.url = "{}://{}".format(protocol, str(self.host).rstrip("/"))
        if self.port is not None and self.port not in (80, 443):
            self.url += ":{}".format(self.port)

        # properties
        self.cache = cache
        self.username = username
        self.password = password

        # generate user-agent info.
        # NOTE(review): ``from .__init__`` loads the package __init__ as a
        # separate module; ``from . import __version__`` is conventional but
        # may not be safe while the package is still initialising (options()
        # runs at import time), so it is left as-is.
        from .__init__ import __version__

        # The header *value* only; "User-Agent" is already the header name.
        agent = "L7 Informatics ESP/{base}.0 ({platform}; esp-client {version} +https://l7informatics.com/esp)".format(
            base=__version__.split(".")[0], platform=sys.platform.capitalize(), version=__version__
        )

        # set up cookies and session
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Connection": "keep-alive",
                "Accept": "application/json",
                "Content-Type": "application/json",
                "X-Frame-Options": "sameorigin",
                "X-Content-Type-Options": "nosniff",
                "User-Agent": agent,
                "X-XSS-Protection": '"1; mode=block"',
                "uEhz08sUFV8gfZGpVjsE": "ignore",  # Temporary fix for cookie domain. see APPS-3235
            }
        )
        if isinstance(headers, dict):
            self.session.headers.update(headers)
        self.cookies = cookies
        self.token = self._handle_token(token)
        if self.token is not None:
            self.session.headers["Authorization"] = "Bearer {}".format(self.token)
        self._logged_in = False
        self._user = None

    def _handle_token(self, token):
        """Normalise ``token`` into the value used in the Bearer header.

        Empty/None gives None. If ``token`` is an existing path, the file's
        lines are stripped and concatenated. A UUID token is base64-encoded.
        """
        if not token:
            return None
        if os.path.exists(token):
            with open(token) as file_:
                token = "".join(line.strip() for line in file_)
        if is_uuid(token):
            token = base64.b64encode(token.encode()).decode()
        return token

    @cached
    def jar(self):
        """
        Property containing cookie jar for use in authentication: the
        Mozilla-format file at ``self.cookies`` if it exists, else None.
        """
        if self.cookies is not None and os.path.exists(self.cookies):
            jar = cookiejar.MozillaCookieJar()
            jar.load(self.cookies)
            return jar
        return None

    @property
    def user(self):
        """
        Property for accessing User object containing
        current user for session (None until ``login`` succeeds).
        """
        return self._user

    def login(self):
        """
        Login to ESP system and configure authentication. This will be
        run implicitly if the user issues any request to any non '/login'
        endpoint.

        With a cookie file or token, the current user is looked up via
        ``/api/profile`` and the raw response is returned. Otherwise a
        basic-auth login is POSTed to ``/login`` and its decoded JSON body
        is returned.

        Raises:
            AssertionError: if the server reports an authentication error.
        """
        if self.cookies or self.token:
            # Send the cookie jar, as every other request does; without it
            # the cookie file would not be used for this lookup.
            ret = self.session.get(self.url + "/api/profile", cookies=self.jar, verify=False)
            user = ret.json()
            if isinstance(user, dict) and "error" in user:
                raise AssertionError("Error occurred during authentication: {}".format(user["error"]))
        else:
            payload = {"tenant": self.tenant} if self.tenant is not None else None
            ret = self.session.post(
                self.url + "/login",
                auth=requests.auth.HTTPBasicAuth(self.username, self.password),
                json=payload,
                verify=False,
            )
            ret = ret.json()
            if "error" in ret:
                raise AssertionError("Error occurred during authentication: {}".format(ret["error"]))
            user = ret["user"]
        self._logged_in = True
        from .models import User

        self._user = User(user["name"])
        return ret

    def logout(self):
        """
        Logout of ESP system from the current session (no-op if not logged in).
        """
        if not self._logged_in:
            return
        auth = None
        if not self.token and not self.cookies:
            auth = requests.auth.HTTPBasicAuth(self.username, self.password)
        # Same cookie jar and TLS settings as every other request.
        self.session.post(
            self.url + "/logout",
            auth=auth,
            cookies=self.jar,
            verify=False,
        )
        self._logged_in = False

    def _send(self, verb, endpoint, **kwargs):
        """Log and issue ``verb`` ("get", "put", "post", "delete") to ``self.url + endpoint``.

        Sends the configured cookie jar with TLS certificate verification
        disabled and returns the ``requests`` response.
        """
        logger.debug("%s %s%s (kwargs: %s)", verb.upper(), self.url, endpoint, kwargs)
        # Only when this module's own logger level is DEBUG (not inherited):
        # dump the calling stack to show who issued the request.
        if logger.level == logging.DEBUG:
            traceback.print_stack()
        return getattr(self.session, verb)(self.url + endpoint, cookies=self.jar, verify=False, **kwargs)

    @logged_in
    @expect(200, retry_on=[404, 504])
    def get(self, endpoint, **kwargs):
        """
        Issue GET request to esp endpoint, and raise exception
        if return status code is not 200. A 404 or 504 is retried once
        unless disabled with ``noretry_on``.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.get``.
        """
        return self._send("get", endpoint, **kwargs)

    @logged_in
    @expect(200, 201, 202)
    def put(self, endpoint, **kwargs):
        """
        Issue PUT request to esp endpoint, and raise exception
        if return status code is not 200, 201, 202.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.put``.
                This will usually include the payload for the request.
        """
        return self._send("put", endpoint, **kwargs)

    @logged_in
    @expect(200, 201, 202)
    def post(self, endpoint, **kwargs):
        """
        Issue POST request to esp endpoint, and raise exception
        if return status code is not 200, 201, 202.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.post``.
                This will usually include the payload for the request.
        """
        return self._send("post", endpoint, **kwargs)

    @logged_in
    @expect(200, 204)
    def delete(self, endpoint, **kwargs):
        """
        Issue DELETE request to esp endpoint, and raise exception
        if return status code is not 200 or 204.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.delete``.
                This will usually include the payload for the request.
        """
        return self._send("delete", endpoint, **kwargs)

    @property
    def options(self):
        """
        Return JSON with options on the session.

        NOTE: includes ``password`` and ``token`` in plain text.
        """
        ret = {
            "url": self.url,
            "host": self.host,
            "port": self.port,
            "cache": self.cache,
            "username": self.username,
            "password": self.password,
            "ssl": self.ssl,
            "token": self.token,
            "config": None,
            "tenant": self.tenant,
        }
        if self.cookies:
            ret["cookies"] = self.cookies
        return ret
# methods
# -------
def options(**kwargs):
    """
    Set global options for session management. You can also configure the system
    to use a default options file located in ``~/.lab7/client.yml``. Here's an example
    options file:

    .. code-block:: console

        $ cat ~/.lab7/client.yml
        username: me@localhost
        password: password
        cache: false

    Settings are merged in this order: built-in defaults (with
    ``LAB7_API_SERVER``, ``LAB7_COOKIE_FILE`` and ``LAB7_API_KEY_FILE`` taken
    from the environment), ``~/.lab7/client.yml``, the file named by
    ``config``, and finally the keyword arguments.

    Args:
        url (str): Server URL. ``ssl``, ``host`` and ``port`` are derived from
            it unless passed explicitly.
        config (str): Optional path to an additional YAML options file. A
            path that does not exist is ignored with a logged warning.
        host (str): Host with database to connect to.
        port (int): Port to connect to database with.
        ssl (bool): Whether to use https.
        username (str): Username to access application with.
        password (str): Password to be used with username.
        cookies (str): Optional path to cookies file to send with requests.
        token(str): Optional API token to use in place of username and password
            for authentication. Can be a string or path to a file containing the
            token.
        headers (dict): Optional extra HTTP headers for every request.
        cache (bool): Whether or not to cache objects for faster querying.
        email (str): Alias for username, retained for backwards compatibility
            with existing scripts, but due for removal in the future.
        tenant (str): Tenancy to use. Only for multi-tenant ESP instances.
            Note: this argument is only necessary for tenancies where the
            same username + password combination exists in multiple tenants.
            A ``db_`` prefix is added if missing.

    Returns:
        The merged configuration, which is also stored in the module-level
        ``CONFIG``; a new module-level ``SESSION`` is built from it.

    Examples:
        >>> # set default options for connections
        >>> import esp
        >>> esp.options(
        >>>     username='user@localhost', password='pass',
        >>>     cookies='/path/to/cookies-file.txt'
        >>> )
        >>>
        >>> # interact with esp client normally
        >>> from esp.models import Protocol
        >>> obj = Protocol('My Protocol')
    """
    from gems import composite

    global SESSION, CONFIG
    if "email" in kwargs and kwargs.get("username") is None:
        warnings.warn("email property is deprecated. Use username instead.", DeprecationWarning, stacklevel=2)
        kwargs["username"] = kwargs.pop("email")

    # load defaults. ``tenant`` has a default here (instead of being reset
    # after merging) so a tenant set in a config file is honoured.
    CONFIG = composite(
        dict(
            url=os.getenv("LAB7_API_SERVER", "http://127.0.0.1:8002"),
            cookies=os.getenv("LAB7_COOKIE_FILE", None),
            token=os.getenv("LAB7_API_KEY_FILE", None),
            username="admin@localhost",
            password="password",
            headers=None,
            cache=True,
            tenant=None,
        )
    )

    # load user config (file handle closed once parsed)
    user_config = os.path.join(os.getenv("HOME", os.path.expanduser("~")), ".lab7", "client.yml")
    if os.path.exists(user_config):
        with open(user_config) as handle:
            CONFIG += composite.from_yaml(handle)

    # load specified config
    config_path = kwargs.get("config")
    if config_path:
        if os.path.exists(config_path):
            with open(config_path) as handle:
                CONFIG += composite.from_yaml(handle)
            del kwargs["config"]
        else:
            logger.warning("Client config file %s does not exist; ignoring it.", config_path)

    # overwrite with any additional options
    CONFIG += composite(dict(kwargs))

    # configure url (if specified)
    if "url" in kwargs:
        CONFIG.url = kwargs["url"]

    # parse url into ssl, host, and port unless given explicitly
    parsed = urllib.parse.urlparse(CONFIG.url)
    netloc = parsed.netloc.split(":")
    if "ssl" not in kwargs:
        CONFIG.ssl = parsed.scheme == "https"
    if "host" not in kwargs:
        CONFIG.host = netloc[0]
    if "port" not in kwargs:
        CONFIG.port = None if len(netloc) == 1 else int(netloc[1])

    # the observed (UI) tenant name and the backend tenant id required for authentication
    # differ by db_, at least in 3.1.
    if CONFIG.tenant and not CONFIG.tenant.startswith("db_"):
        CONFIG.tenant = "db_" + CONFIG.tenant

    # establish session
    SESSION = Session(
        host=CONFIG.host,
        port=CONFIG.port,
        cookies=CONFIG.cookies,
        username=CONFIG.username,
        password=CONFIG.password,
        ssl=CONFIG.ssl,
        cache=CONFIG.cache,
        token=CONFIG.token,
        headers=CONFIG.headers,
        tenant=CONFIG.tenant,
    )
    return CONFIG


# Build a default CONFIG/SESSION at import time from the environment and
# ~/.lab7/client.yml, so the client works without an explicit options() call.
options()