from __future__ import annotations
import string
from collections.abc import Mapping
from dataclasses import asdict, dataclass
from io import BytesIO
from typing import IO, Any, Literal, NamedTuple, cast
from urllib.parse import urlencode
import sentry_sdk
import urllib3
import zstandard
from urllib3.connectionpool import HTTPConnectionPool
from objectstore_client.metadata import (
HEADER_EXPIRATION,
HEADER_META_PREFIX,
Compression,
ExpirationPolicy,
Metadata,
format_expiration,
)
from objectstore_client.metrics import (
MetricsBackend,
NoOpMetricsBackend,
measure_storage_operation,
)
Permission = Literal["read", "write"]
class GetResult(NamedTuple):
metadata: Metadata
payload: IO[bytes]
class RequestError(Exception):
"""Exception raised if an API call to Objectstore fails."""
def __init__(self, message: str, status: int, response: str):
super().__init__(message)
self.status = status
self.response = response
class Usecase:
"""
An identifier for a workload in Objectstore, along with defaults to use for all
operations within that Usecase.
Usecases need to be statically defined in Objectstore's configuration server-side.
Objectstore can make decisions based on the Usecase, for example choosing the
most suitable storage backend.
"""
name: str
_compression: Compression
_expiration_policy: ExpirationPolicy | None
def __init__(
self,
name: str,
compression: Compression = "zstd",
expiration_policy: ExpirationPolicy | None = None,
):
self.name = name
self._compression = compression
self._expiration_policy = expiration_policy
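# A minimal usage sketch for `Usecase` (the name "attachments" below is
# hypothetical; it must match a usecase statically configured on the
# Objectstore server):
#
#     attachments = Usecase("attachments", compression="zstd")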
# Characters allowed in a Scope's key and value.
# These are the URL-safe characters, except for `.`, which we use as the
# separator between the key and value of Scope components.
SCOPE_VALUE_ALLOWED_CHARS = set(string.ascii_letters + string.digits + "-_()$!+'")
@dataclass
class _ConnectionDefaults:
retries: urllib3.Retry = urllib3.Retry(connect=3, read=0)
"""We only retry connection problems, as we cannot rewind our compression stream."""
timeout: urllib3.Timeout = urllib3.Timeout(connect=0.5, read=0.5)
"""
The read timeout is defined to be "between consecutive read operations",
which should mean one chunk of the response, with a large response being
split into multiple chunks.
We define both as 500ms, which is still very conservative given that we are
on the same network and expect our backends to respond in <100ms.
"""
class Client:
"""A client for Objectstore. Constructing it initializes a connection pool."""
def __init__(
self,
base_url: str,
metrics_backend: MetricsBackend | None = None,
propagate_traces: bool = False,
retries: int | None = None,
timeout_ms: float | None = None,
connection_kwargs: Mapping[str, Any] | None = None,
):
connection_kwargs_to_use = asdict(_ConnectionDefaults())
if retries is not None:
connection_kwargs_to_use["retries"] = urllib3.Retry(
connect=retries,
# we only retry connection problems, as we cannot rewind our
# compression stream
read=0,
)
if timeout_ms is not None:
# urllib3.Timeout takes seconds, so convert from milliseconds
connection_kwargs_to_use["timeout"] = urllib3.Timeout(
connect=timeout_ms / 1000, read=timeout_ms / 1000
)
if connection_kwargs:
connection_kwargs_to_use = {**connection_kwargs_to_use, **connection_kwargs}
self._pool = urllib3.connectionpool.connection_from_url(
base_url, **connection_kwargs_to_use
)
self._metrics_backend = metrics_backend or NoOpMetricsBackend()
self._propagate_traces = propagate_traces
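# Constructing a client (illustrative; the URL and tuning values are
# hypothetical, and omitted arguments fall back to `_ConnectionDefaults`):
#
#     client = Client(
#         "http://objectstore.internal:8888",
#         retries=3,        # connection retries only; reads are never retried
#         timeout_ms=1000,  # 1s connect/read timeout instead of the 500ms default
#     )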
def session(self, usecase: Usecase, **scopes: str | int | bool) -> Session:
"""
Create a [Session] with the Objectstore server, tied to a specific [Usecase] and
Scope.
A Scope is a (possibly nested) namespace within a Usecase, given as a sequence
of key-value pairs passed as kwargs.
IMPORTANT: the order of the kwargs matters!
The allowed characters for keys and values are: `A-Za-z0-9-_()$!+'`.
Users are free to choose the scope structure that best suits their Usecase.
The combination of Usecase and Scope will determine the physical key/path of the
blob in the underlying storage backend.
For most usecases, it's recommended to use the organization and project ID as
the first components of the scope, as follows:
```
client.session(usecase, org=organization_id, project=project_id, ...)
```
"""
parts = []
for key, value in scopes.items():
if not key:
raise ValueError("Scope key cannot be empty")
if not value:
raise ValueError("Scope value cannot be empty")
if any(c not in SCOPE_VALUE_ALLOWED_CHARS for c in key):
raise ValueError(
f"Invalid scope key {key}. The valid character set is: "
f"{''.join(SCOPE_VALUE_ALLOWED_CHARS)}"
)
value = str(value)
if any(c not in SCOPE_VALUE_ALLOWED_CHARS for c in value):
raise ValueError(
f"Invalid scope value {value}. The valid character set is: "
f"{''.join(SCOPE_VALUE_ALLOWED_CHARS)}"
)
formatted = f"{key}.{value}"
parts.append(formatted)
scope_str = "/".join(parts)
return Session(
self._pool,
self._metrics_backend,
self._propagate_traces,
usecase,
scope_str,
)
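# Scope kwargs are joined, in order, into `key.value` components separated by
# `/`. For example (hypothetical ids):
#
#     session = client.session(attachments, org=42, project=1337)
#     # resulting scope string: "org.42/project.1337"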
class Session:
"""
A session with the Objectstore server, scoped to a specific [Usecase] and Scope.
This should never be constructed directly, use [Client.session].
"""
def __init__(
self,
pool: HTTPConnectionPool,
metrics_backend: MetricsBackend,
propagate_traces: bool,
usecase: Usecase,
scope: str,
):
self._pool = pool
self._metrics_backend = metrics_backend
self._propagate_traces = propagate_traces
self._usecase = usecase
self._scope = scope
def _make_headers(self) -> dict[str, str]:
if self._propagate_traces:
return dict(sentry_sdk.get_current_scope().iter_trace_propagation_headers())
return {}
def _make_url(self, id: str | None, full: bool = False) -> str:
base_path = f"/v1/{id}" if id else "/v1/"
qs = urlencode({"usecase": self._usecase.name, "scope": self._scope})
if full:
return f"http://{self._pool.host}:{self._pool.port}{base_path}?{qs}"
else:
return f"{base_path}?{qs}"
def put(
self,
contents: bytes | IO[bytes],
id: str | None = None,
compression: Compression | Literal["none"] | None = None,
content_type: str | None = None,
metadata: dict[str, str] | None = None,
expiration_policy: ExpirationPolicy | None = None,
) -> str:
"""
Uploads the given `contents` to blob storage.
If no `id` is provided, one will be automatically generated and returned
from this function.
If no `compression` argument is given explicitly, the client falls back to
the compression configured on the [Usecase].
Providing `"none"` instructs the client to not apply any compression to
this upload, which is useful for incompressible formats.
"""
headers = self._make_headers()
body = BytesIO(contents) if isinstance(contents, bytes) else contents
original_body: IO[bytes] = body
compression = compression or self._usecase._compression
if compression == "zstd":
cctx = zstandard.ZstdCompressor()
body = cctx.stream_reader(original_body)
headers["Content-Encoding"] = "zstd"
if content_type:
headers["Content-Type"] = content_type
expiration_policy = expiration_policy or self._usecase._expiration_policy
if expiration_policy:
headers[HEADER_EXPIRATION] = format_expiration(expiration_policy)
if metadata:
for k, v in metadata.items():
headers[f"{HEADER_META_PREFIX}{k}"] = v
with measure_storage_operation(
self._metrics_backend, "put", self._usecase.name
) as metric_emitter:
response = self._pool.request(
"PUT",
self._make_url(id),
body=body,
headers=headers,
preload_content=True,
decode_content=True,
)
raise_for_status(response)
res = response.json()
# Must do this after streaming `body` as that's what is responsible
# for advancing the seek position in both streams
metric_emitter.record_uncompressed_size(original_body.tell())
if compression and compression != "none":
metric_emitter.record_compressed_size(body.tell(), compression)
return res["key"]
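# Usage sketch for `put` (ids, file names, and metadata are hypothetical):
#
#     key = session.put(b"some payload")  # id is generated server-side
#     key = session.put(
#         open("trace.bin", "rb"),
#         id="trace-1",
#         content_type="application/octet-stream",
#         metadata={"schema": "v2"},
#     )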
def get(self, id: str, decompress: bool = True) -> GetResult:
"""
Fetches the blob with the given `id`, returning a readable `IO[bytes]`
stream.
By default, content that was uploaded compressed is transparently
decompressed, unless `decompress=False` is passed.
"""
headers = self._make_headers()
with measure_storage_operation(
self._metrics_backend, "get", self._usecase.name
):
response = self._pool.request(
"GET",
self._make_url(id),
preload_content=False,
decode_content=False,
headers=headers,
)
raise_for_status(response)
# The urllib3 response object is itself file-like, so we can hand it out
# directly as the payload stream.
stream = cast(IO[bytes], response)
metadata = Metadata.from_headers(response.headers)
if metadata.compression and decompress:
if metadata.compression != "zstd":
raise NotImplementedError(
"Transparent decoding of anything but `zstd` is not implemented yet"
)
metadata.compression = None
dctx = zstandard.ZstdDecompressor()
stream = dctx.stream_reader(stream, read_across_frames=True)
return GetResult(metadata, stream)
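# Usage sketch for `get` (assumes `key` was returned by a prior `put`):
#
#     result = session.get(key)
#     data = result.payload.read()  # transparently zstd-decompressed
#     raw = session.get(key, decompress=False).payload.read()  # still compressed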
def object_url(self, id: str) -> str:
"""
Generates a GET URL for the object with the given `id`.
This can then be used by downstream services to fetch the given object.
NOTE however that the service does not strictly follow HTTP semantics,
in particular in relation to `Accept-Encoding`.
"""
return self._make_url(id, full=True)
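# E.g. handing a blob URL to a downstream service (hypothetical key; note
# that `urlencode` percent-encodes the `/` separators in the scope):
#
#     url = session.object_url("trace-1")
#     # -> "http://<host>:<port>/v1/trace-1?usecase=...&scope=org.42%2Fproject.1337"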
def delete(self, id: str) -> None:
"""
Deletes the blob with the given `id`.
"""
headers = self._make_headers()
with measure_storage_operation(
self._metrics_backend, "delete", self._usecase.name
):
response = self._pool.request(
"DELETE",
self._make_url(id),
headers=headers,
)
raise_for_status(response)
def raise_for_status(response: urllib3.BaseHTTPResponse) -> None:
if response.status >= 400:
res = (response.data or response.read()).decode("utf-8", errors="replace")
raise RequestError(
f"Objectstore request failed with status {response.status}",
response.status,
res,
)
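# Error handling sketch: any response with status >= 400 surfaces as a
# `RequestError` carrying the status code and response body (ids are
# hypothetical):
#
#     try:
#         session.delete("trace-1")
#     except RequestError as e:
#         print(e.status, e.response)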