# Source code for esp.base

# -*- coding: utf-8 -*-
#
# Session management.
#
# ------------------------------------------------


# imports
# -------
import atexit
import base64
from contextvars import ContextVar
from collections import Counter
from dataclasses import dataclass, field
from functools import wraps
import json
import logging
import os
import re
import sys
from typing import Mapping
import traceback
import warnings
import time

try:
    import cookielib as cookiejar
except ImportError:
    import http.cookiejar as cookiejar

from gems import cached
import requests
from requests import ConnectionError
from requests.packages import urllib3
import six.moves.urllib as urllib

from .utils import debugtime, is_uuid

from espclient.config import DEFAULT_CLIENT_COOKIES

# Per-call API timing (and the exit-time report) is enabled unless the
# L7_CLIENT_ENABLE_TIMING environment variable is set to something other
# than "true" (compared case-insensitively).
_ENABLE_TIMING = os.environ.get("L7_CLIENT_ENABLE_TIMING", "true").lower() == "true"
# config
# ------
# Module-wide defaults; replaced by ``options()`` when it runs.
CONFIG = SESSION = None
_debug = debugtime("API")

# Requests in this module pass ``verify=False``; silence urllib3's
# InsecureRequestWarning so it is not emitted for each of them.
urllib3.disable_warnings(category=urllib3.exceptions.InsecureRequestWarning)

logger = logging.getLogger(__name__)

# Reference point for "approximate_execution_time" in the timing report.
_module_load_time = time.perf_counter_ns()

# Divisors that convert a nanosecond count into the named unit.
_NS_TO_UNIT_DIVISOR = dict(ns=1, us=1e3, ms=1e6, s=1e9)


@dataclass
class ApiCall:
    """Aggregated timing statistics for one "base" API (method + normalized path).

    Durations are accumulated in nanoseconds (the unit of
    ``time.perf_counter_ns``); ``report_list``/``report_string`` convert them
    to the requested unit.
    """

    # A "/"-prefixed, 36-character path segment of hex digits and dashes (either
    # case) that is followed by another segment or the end of the path.
    # Not annotated, so the dataclass machinery does not treat it as a field.
    _UUID_SEGMENT_RE = re.compile(r"/[0-9a-fA-F-]{36}(?=/|$)")

    api: str
    method: str
    call_count: int = 0
    # Sentinel larger than any realistic duration (10**18 ns is ~32 years), so
    # the first recorded call always replaces it. Kept an int to match the
    # annotation (the previous ``1e18`` literal was a float).
    min_time: int = 10**18
    max_time: int = 0
    cumulative_time: int = 0
    # Full (un-normalized) endpoint -> number of calls made to it. This is
    # mutated in place, so it is annotated as a Counter rather than a read-only
    # Mapping.
    unique_call_counts: "Counter[str]" = field(default_factory=Counter)

    @classmethod
    def base_api(cls, method, endpoint) -> str:
        """Return the "base" API string used to aggregate timing data.

        Every path segment that looks like a UUID is replaced with ``:uuid``.
        The method is kept unchanged, and any scheme, host, query string or
        fragment of ``endpoint`` is dropped. For instance::

            GET /api/samples/608f4f2e-843c-492c-bec6-a51f6f03e0ba -> GET /api/samples/:uuid
            GET /api/samples -> GET /api/samples
            PUT /api/samples/608f4f2e-843c-492c-bec6-a51f6f03e0ba -> PUT /api/samples/:uuid
            GET /api/a/<uuid>/b/<uuid> -> GET /api/a/:uuid/b/:uuid

        The distinction between with/without UUID makes it easier to tell bulk
        fetch/update calls apart from individual calls, while still allowing
        aggregate reporting of the individual calls.

        Args:
            method (str): HTTP method name, e.g. ``"GET"``.
            endpoint (str): Relative endpoint or full URL that was requested.

        Returns:
            str: ``"<method> <normalized path>"``.
        """
        path = urllib.parse.urlparse(endpoint).path
        return f"{method} {cls._UUID_SEGMENT_RE.sub('/:uuid', path)}"

    def update_timing_data(self, endpoint, total_time):
        """Record one call to ``endpoint`` that took ``total_time`` nanoseconds."""
        self.call_count += 1
        self.unique_call_counts[endpoint] += 1
        self.cumulative_time += total_time
        self.min_time = min(self.min_time, total_time)
        self.max_time = max(self.max_time, total_time)

    @property
    def mean_time(self):
        """Mean duration per call in nanoseconds; 0.0 when nothing was recorded."""
        if self.call_count > 0:
            return self.cumulative_time / self.call_count
        return 0.0

    def report_string(self, units="ns"):
        """Return ``report_list(units)`` joined with commas."""
        return ",".join(str(x) for x in self.report_list(units))

    def report_list(self, units="ns"):
        """Return ``[api, calls, cumulative, min, max, mean, unique endpoints]``.

        Times are divided into ``units`` ("ns", "us", "ms" or "s"); an unknown
        unit leaves them in nanoseconds.
        """
        divisor = _NS_TO_UNIT_DIVISOR.get(units, 1)
        return [
            self.api,
            self.call_count,
            self.cumulative_time / divisor,
            self.min_time / divisor,
            self.max_time / divisor,
            self.mean_time / divisor,
            len(self.unique_call_counts),
        ]


# no typehints for better backwards compatibility.
_api_calls = {}
# with typehints for better comprehension - can use this form
# starting with python 3.9, so when we drop support for
# python 3.8, uncomment this line, remove _api_calls = {}
# above, and remove this comment.
# _api_calls: dict[str, ApiCall] = {}


def _report_apicalls(*args, **kwargs):
    """Write the aggregated API timing data to stderr as a JSON document.

    Registered with ``atexit`` when timing is enabled. Any positional or keyword
    arguments are accepted and ignored. Nothing is written if no API calls were
    recorded. Times are reported in milliseconds. Calls are ordered by
    cumulative time, then by mean time, largest first. Each call row is written
    on its own line, and the document as a whole is valid JSON.
    """
    program_exit_time = time.perf_counter_ns()
    calls = sorted(
        _api_calls.values(), key=lambda call: (call.cumulative_time, call.mean_time), reverse=True
    )
    if not calls:
        return
    units = "ms"
    divisor = _NS_TO_UNIT_DIVISOR.get(units, 1)
    elapsed = (program_exit_time - _module_load_time) / divisor
    headers = ["API", "#Calls", "Cumulative", "Min", "Max", "Mean", "#UniqueEndpoints"]
    # Field separators, row separators and the closing brace are all required
    # for the output to parse as JSON; headers go through json.dumps rather
    # than Python's list repr (which uses single quotes).
    rows = ",\n".join(f"  {json.dumps(call.report_list(units))}" for call in calls)
    lines = [
        "{",
        f' "units": {json.dumps(units)},',
        f' "approximate_execution_time": {json.dumps(elapsed)},',
        f' "call_headers": {json.dumps(headers)},',
        ' "calls": [',
        rows,
        " ]",
        "}",
    ]
    print("\n".join(lines), file=sys.stderr)
    sys.stderr.flush()


if _ENABLE_TIMING:
    atexit.register(_report_apicalls)


_noretry_codes: ContextVar[set] = ContextVar("_noretry_codes")


class noretry_on:
    def __init__(self, *codes):
        self.skip_retry_codes = set(codes)
        self.token = None

    def __enter__(self):
        self.token = _noretry_codes.set(self.skip_retry_codes)

    def __exit__(self, exc_type, exc_value, exc_tb):
        _noretry_codes.reset(self.token)
        self.token = None


# decorators
def _raise_unexpected_response(res, codes):
    """Raise AssertionError describing a response whose status is not in ``codes``.

    The body is decoded as UTF-8 JSON, and its ``error`` and ``traceback`` keys
    are included in the message. If the body cannot be decoded, the message
    reports it as malformed and includes the raw content instead.
    """
    try:
        content = json.loads(res.content.decode("utf8"))
    except (ValueError, AttributeError) as exc:
        # ValueError covers both JSONDecodeError and UnicodeDecodeError;
        # AttributeError covers a body without ``.decode`` (e.g. None).
        raise AssertionError(
            "\n\n".join(["Response from server malformed (could not be decoded):", str(res.content)])
        ) from exc
    if not isinstance(content, dict):
        # Valid JSON that is not an object: report it as the error text
        # instead of crashing on ``.get`` below.
        content = {"error": content}
    raise AssertionError(
        "\n\n".join(
            [
                "Response from server ({}) different from "
                "expected {}".format(res.status_code, codes),
                "SERVER ERROR: {}".format(content.get("error")),
                "SERVER TRACEBACK: {}".format(content.get("traceback")),
            ]
        )
    )


def expect(*codes, retry_on=None, attempts=2):
    """Decorator factory that checks the status code of a request method's response.

    The decorated function is called as ``func(self, endpoint, **kwargs)``. It
    must return an object with ``status_code`` and ``content`` attributes (as
    ``requests.Response`` does).

    Args:
        *codes (int): Acceptable status codes.
        retry_on (iterable of int): Status codes that trigger another attempt.
            Codes listed by an active ``noretry_on`` block are excluded.
        attempts (int): Total number of calls made when connection errors or
            ``retry_on`` statuses keep occurring. Before each retry the wrapper
            sleeps 2, 4, 8, ... seconds.

    Raises:
        ValueError: At decoration time, if ``attempts`` is less than 1.
        AssertionError: If the final response's status is not in ``codes``.
        ConnectionError: Re-raised when the final attempt cannot connect.

    When timing is enabled, the total time spent (including retries and sleeps)
    is recorded in ``_api_calls`` whether the call succeeds or fails.
    """
    # ``retry_on``/``attempts`` are keyword-only because they follow ``*codes``
    # (PEP 3102).
    if attempts < 1:
        raise ValueError("attempts must be at least 1")
    retry_on = set(retry_on or ())

    def decorator(func):
        @wraps(func)
        def _(self, endpoint, **kwargs):
            start_time = time.perf_counter_ns()
            local_retry = retry_on - _noretry_codes.get(set())
            try:
                for attempt in range(attempts):
                    is_last = attempt == attempts - 1
                    try:
                        res = func(self, endpoint, **kwargs)
                    except ConnectionError:
                        if is_last:
                            raise
                        # simple power-of-2 exponential backoff
                        time.sleep(2 ** (attempt + 1))
                        continue
                    if res.status_code in local_retry and not is_last:
                        time.sleep(2 ** (attempt + 1))
                        continue
                    # Final attempt, or a status that is not retried: it must
                    # be one of the expected codes. The loop never falls
                    # through without returning or raising.
                    if res.status_code not in codes:
                        _raise_unexpected_response(res, codes)
                    return res
            finally:
                if _ENABLE_TIMING:
                    total_time = time.perf_counter_ns() - start_time
                    method = func.__name__.upper()
                    api = ApiCall.base_api(method=method, endpoint=endpoint)
                    _api_calls.setdefault(api, ApiCall(api=api, method=method)).update_timing_data(endpoint, total_time)

        return _

    return decorator


def logged_in(func):
    """Decorator that calls ``self.login()`` before the wrapped request method.

    The login is skipped when ``self._logged_in`` is true or when the endpoint
    contains ``"login"``. The wrapped function is called as
    ``func(self, endpoint, **kwargs)``.

    Raises:
        ConnectionError: If ``self.login()`` raises. The original exception is
            chained as ``__cause__``. KeyboardInterrupt and SystemExit are not
            intercepted.
    """

    @wraps(func)
    def _(self, endpoint, **kwargs):
        if not self._logged_in and "login" not in endpoint:
            try:
                self.login()
            except Exception as exc:
                raise ConnectionError(
                    "Could not connect to ESP. Either the server is down or login credentials are invalid"
                ) from exc
        return func(self, endpoint, **kwargs)

    return _


# objects
# -------
class Session(object):
    """
    Object for managing connections to esp.

    Args:
        host (str): Host with database to connect to.
        port (int): Port to connect to database with.
        cookies (str): Optional path to cookies file to store session data.
        headers (dict): Optional extra headers added to every request.
        username (str): Username to access application with. Defaults to
            ``"admin@localhost"`` when neither username nor email is given.
        password (str): Password to be used with username.
        cache (bool): Whether or not to cache objects for faster querying.
        ssl (bool): Use https even when the port is not 443.
        token(str): Optional API token to use in place of username and password
            for authentication. Can be a string or path to a file containing
            the token.
        email (str): Username to access application with. Deprecated but
            preserved through at least 2.3 for backwards compatibility. If both
            username and email are provided, username wins.
        tenant (str): Tenant sent with the ``/login`` request, if any.

    Attributes:
        session (requests.Session): Internal ``requests.Session`` object for
            managing connection.
    """

    def __init__(
        self,
        host="127.0.0.1",
        port=8002,
        cookies=None,
        headers=None,
        username=None,
        password="password",
        cache=False,
        ssl=False,
        token=None,
        email=None,
        tenant=None,
    ):
        # ``username`` defaults to None so that a deprecated ``email`` passed on
        # its own is honoured. With the old non-None default, ``email`` was
        # silently ignored unless username=None was passed explicitly.
        if username is None:
            if email is not None:
                warnings.warn("email property is deprecated. Use username instead.", DeprecationWarning, stacklevel=2)
                username = email
            else:
                username = "admin@localhost"

        # save basics
        self.host = host
        self.port = port
        self.ssl = ssl
        self.tenant = tenant

        # set up url; default ports are left out of it
        protocol = "https" if self.ssl or port == 443 else "http"
        self.url = "{}://{}".format(protocol, self.host)
        if self.port is not None and self.port not in (80, 443):
            self.url += ":{}".format(self.port)
        if self.url.endswith("/"):
            self.url = self.url[:-1]

        # properties
        self.cache = cache
        self.username = username
        self.password = password

        # generate user-agent info
        # NOTE(review): importing ``.__init__`` by name loads __init__.py a second
        # time as a separate module; ``from . import __version__`` would avoid
        # that, but only if __version__ is defined before this module is imported
        # during package initialisation. Confirm before changing.
        from .__init__ import __version__

        # The value must not repeat the header name: requests sends it verbatim
        # as the User-Agent value (previously "User-Agent: User-Agent: L7 ...").
        agent = "L7 Informatics ESP/{base}.0 ({platform}; esp-client {version} +https://l7informatics.com/esp)".format(
            base=__version__.split(".")[0], platform=sys.platform.capitalize(), version=__version__
        )

        # set up cookies and session
        self.session = requests.Session()
        self.session.headers.update(
            {
                "Connection": "keep-alive",
                "Accept": "application/json",
                "Content-Type": "application/json",
                "X-Frame-Options": "sameorigin",
                "X-Content-Type-Options": "nosniff",
                "User-Agent": agent,
                "X-XSS-Protection": '"1; mode=block"',
                "uEhz08sUFV8gfZGpVjsE": "ignore",  # Temporary fix for cookie domain. see APPS-3235
            }
        )
        if isinstance(headers, dict):
            self.session.headers.update(headers)
        self.cookies = cookies
        self.token = self._handle_token(token)
        if self.token is not None:
            self.session.headers["Authorization"] = "Bearer {}".format(self.token)
        self._logged_in = False
        self._user = None

    def _handle_token(self, token):
        """Resolve ``token`` to the value sent as the Bearer authorization.

        ``token`` may be the token itself or a path to a file containing it. For
        a file, each line is stripped and the lines are concatenated. When
        ``is_uuid(token)`` is true, the token is base64-encoded. A falsy token
        yields None.
        """
        if not token:
            return None
        if os.path.exists(token):
            with open(token) as file_:
                token = "".join(line.strip() for line in file_)
        if is_uuid(token):
            token = base64.b64encode(token.encode()).decode()
        return token

    @cached
    def jar(self):
        """
        Property containing cookie jar for use in authentication.

        A ``MozillaCookieJar`` loaded from ``self.cookies`` when that file
        exists, otherwise None.
        """
        if self.cookies is not None and os.path.exists(self.cookies):
            jar = cookiejar.MozillaCookieJar()
            jar.load(self.cookies)
            return jar
        return None

    @property
    def user(self):
        """
        Property for accessing User object containing current user for session.
        """
        return self._user

    def login(self):
        """
        Login to ESP system and configure authentication. This will be run
        implicitly if the user issues any request to any non '/login' endpoint.

        With cookies or a token, the current user is read from
        ``/api/profile`` and that response object is returned. Otherwise,
        basic-auth credentials (plus the tenant, if set) are POSTed to
        ``/login`` and the decoded JSON payload is returned.

        Raises:
            AssertionError: If the ``/login`` payload contains an ``error``.
        """
        if self.cookies or self.token:
            # Send the cookie jar here as every other request does; without it,
            # cookie-file authentication cannot identify the user.
            ret = self.session.get(self.url + "/api/profile", cookies=self.jar, verify=False)
            user = ret.json()
        else:
            payload = {"tenant": self.tenant} if self.tenant is not None else None
            ret = self.session.post(
                self.url + "/login",
                auth=requests.auth.HTTPBasicAuth(self.username, self.password),
                json=payload,
                verify=False,
            )
            ret = ret.json()
            if "error" in ret:
                raise AssertionError("Error occurred during authentication: {}".format(ret["error"]))
            user = ret["user"]
        self._logged_in = True
        from .models import User

        self._user = User(user["name"])
        return ret

    def logout(self):
        """
        Logout of ESP system from the current session.

        Does nothing when the session is not logged in.
        """
        if not self._logged_in:
            return
        auth = None
        if not self.token and not self.cookies:
            auth = requests.auth.HTTPBasicAuth(self.username, self.password)
        # Same cookie jar and TLS settings as every other request of this session.
        self.session.post(
            self.url + "/logout",
            auth=auth,
            cookies=self.jar,
            verify=False,
        )
        self._logged_in = False

    def _send_request(self, verb, endpoint, **kwargs):
        """Issue ``verb`` (e.g. "GET") for ``self.url + endpoint`` via ``self.session``.

        The cookie jar and ``verify=False`` are always passed, and ``kwargs`` are
        forwarded. When this module's logger level is exactly DEBUG, the call
        stack is printed as well.
        """
        logger.debug("%s %s%s (kwargs: %s)", verb, self.url, endpoint, kwargs)
        if logger.level == logging.DEBUG:
            traceback.print_stack()
        send = getattr(self.session, verb.lower())
        return send(self.url + endpoint, cookies=self.jar, verify=False, **kwargs)

    @logged_in
    @expect(200, retry_on=[404, 504])
    def get(self, endpoint, **kwargs):
        """
        Issue GET request to esp endpoint, and raise exception if return status
        code is not 200. 404 and 504 responses are retried before failing.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.get``.
        """
        return self._send_request("GET", endpoint, **kwargs)

    @logged_in
    @expect(200, 201, 202)
    def put(self, endpoint, **kwargs):
        """
        Issue PUT request to esp endpoint, and raise exception if return status
        code is not 200, 201, 202.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.put``. This
                will usually include the payload for the request.
        """
        return self._send_request("PUT", endpoint, **kwargs)

    @logged_in
    @expect(200, 201, 202)
    def post(self, endpoint, **kwargs):
        """
        Issue POST request to esp endpoint, and raise exception if return status
        code is not 200, 201, 202.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.post``. This
                will usually include the payload for the request.
        """
        return self._send_request("POST", endpoint, **kwargs)

    @logged_in
    @expect(200, 204)
    def delete(self, endpoint, **kwargs):
        """
        Issue DELETE request to esp endpoint, and raise exception if return
        status code is not 200 or 204.

        Args:
            endpoint (str): String with relative api endpoint.
            **kwargs: Additional arguments to pass into ``request.delete``. This
                will usually include the payload for the request.
        """
        return self._send_request("DELETE", endpoint, **kwargs)

    @property
    def options(self):
        """
        Return JSON with options on the session.

        Note that the password and token are included in plain text.
        """
        ret = {
            "url": self.url,
            "host": self.host,
            "port": self.port,
            "cache": self.cache,
            "username": self.username,
            "password": self.password,
            "ssl": self.ssl,
            "token": self.token,
            "config": None,
            "tenant": self.tenant,
        }
        if self.cookies:
            ret["cookies"] = self.cookies
        return ret
# methods
# -------
def options(**kwargs):
    """
    Set global options for session management. You can also configure the
    system to use a default options file located in ``~/.lab7/client.yml``.
    Here's an example options file:

    .. code-block:: console

        $ cat ~/.lab7/client.yml
        username: me@localhost
        password: password
        cache: false

    Options are layered, with later sources overriding earlier ones:

    * built-in defaults, partly taken from the ``LAB7_API_SERVER``,
      ``LAB7_COOKIE_FILE`` and ``LAB7_API_KEY_FILE`` environment variables;
    * ``~/.lab7/client.yml``;
    * the YAML file named by ``config``, if it exists;
    * the keyword arguments themselves.

    ``ssl``, ``host`` and ``port`` are derived from ``url`` unless they are
    passed explicitly.

    Args:
        url (str): Server URL, e.g. ``http://127.0.0.1:8002``.
        host (str): Host with database to connect to.
        port (int): Port to connect to database with.
        username (str): Username to access application with.
        password (str): Password to be used with username.
        cookies (str): Optional path to cookies file to store session data.
        token(str): Optional API token to use in place of username and password
            for authentication. Can be a string or path to a file containing
            the token.
        cache (bool): Whether or not to cache objects for faster querying.
        headers (dict): Optional extra headers added to every request.
        config (str): Optional path to an additional YAML options file.
        email (str): Alias for username, retained for backwards compatibility
            with existing scripts, but due for removal in the future.
        tenant (str): Tenancy to use. Only for multi-tenant ESP instances.
            Note: this argument is only necessary for tenancies where the same
            username + password combination exists in multiple tenants.

    Returns:
        The merged configuration. It is also stored in the module-level
        ``CONFIG``, and the module-level ``SESSION`` is replaced with a new
        ``Session`` built from it.

    Examples:
        >>> # set default options for connections
        >>> import esp
        >>> esp.options(
        ...     username='user@localhost', password='pass',
        ...     cookies='/path/to/cookies-file.txt'
        ... )
        >>>
        >>> # interact with esp client normally
        >>> from esp.models import Protocol
        >>> obj = Protocol('My Protocol')
    """
    from gems import composite

    global SESSION, CONFIG

    if "email" in kwargs and kwargs.get("username") is None:
        warnings.warn("email property is deprecated. Use username instead.", DeprecationWarning, stacklevel=2)
        kwargs["username"] = kwargs.pop("email")

    # load defaults
    CONFIG = composite(
        dict(
            url=os.getenv("LAB7_API_SERVER", "http://127.0.0.1:8002"),
            cookies=os.getenv("LAB7_COOKIE_FILE", None),
            token=os.getenv("LAB7_API_KEY_FILE", None),
            username="admin@localhost",
            password="password",
            headers=None,
            cache=True,
        )
    )

    # load user config (file handles are closed once parsed)
    user_config = os.path.join(os.getenv("HOME", os.path.expanduser("~")), ".lab7", "client.yml")
    if os.path.exists(user_config):
        with open(user_config) as stream:
            CONFIG += composite.from_yaml(stream)

    # load specified config
    config_path = kwargs.get("config")
    if config_path and os.path.exists(config_path):
        with open(config_path) as stream:
            CONFIG += composite.from_yaml(stream)
        del kwargs["config"]

    # overwrite with any additional options
    CONFIG += composite(dict(kwargs))

    # configure url (if specified)
    if "url" in kwargs:
        CONFIG.url = kwargs["url"]

    # parse url into ssl, host, and port
    parsed = urllib.parse.urlparse(CONFIG.url)
    netloc = parsed.netloc.split(":")
    if "ssl" not in kwargs:
        CONFIG.ssl = parsed.scheme == "https"
    if "host" not in kwargs:
        CONFIG.host = netloc[0]
    if "port" not in kwargs:
        CONFIG.port = None if len(netloc) == 1 else int(netloc[1])
    if "tenant" not in kwargs:
        # NOTE(review): this also discards any tenant set in client.yml or the
        # ``config`` file, since those were merged above. Confirm this is intended.
        CONFIG.tenant = None
    # the observed (UI) tenant name and the backend tenant id required for authentication
    # differ by db_, at least in 3.1.
    if CONFIG.tenant and not CONFIG.tenant.startswith("db_"):
        CONFIG.tenant = "db_" + CONFIG.tenant

    # establish session
    SESSION = Session(
        host=CONFIG.host,
        port=CONFIG.port,
        cookies=CONFIG.cookies,
        username=CONFIG.username,
        password=CONFIG.password,
        ssl=CONFIG.ssl,
        cache=CONFIG.cache,
        token=CONFIG.token,
        headers=CONFIG.headers,
        tenant=CONFIG.tenant,
    )
    return CONFIG
# Initialize the module-level CONFIG and SESSION at import time from the
# default option sources (environment variables and ~/.lab7/client.yml).
options()