From 1ecd3bd0ac279353e80db6fd485d4c62ba9c8003 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Thu, 5 Dec 2024 16:24:18 +0100 Subject: [PATCH 01/75] feat: submit aggregate tunnel metrics This commit contains a work-in-progress implementation of the aggregate tunnel metrics submission. For now, we have sketched out the basic architecture to show the functionality. There are still several unaddressed talking points to explore. --- dashboard/parse_all_logs.py | 1 - oonisubmitter/DESIGN.md | 28 ++++ oonisubmitter/aggregator.py | 176 ++++++++++++++++++++++ oonisubmitter/model.py | 284 ++++++++++++++++++++++++++++++++++++ oonisubmitter/serializer.py | 207 ++++++++++++++++++++++++++ oonisubmitter/submitter.py | 128 ++++++++++++++++ 6 files changed, 823 insertions(+), 1 deletion(-) create mode 100644 oonisubmitter/DESIGN.md create mode 100644 oonisubmitter/aggregator.py create mode 100644 oonisubmitter/model.py create mode 100644 oonisubmitter/serializer.py create mode 100644 oonisubmitter/submitter.py diff --git a/dashboard/parse_all_logs.py b/dashboard/parse_all_logs.py index 008b821..d078ed0 100644 --- a/dashboard/parse_all_logs.py +++ b/dashboard/parse_all_logs.py @@ -8,7 +8,6 @@ The resulting csv file is called 'metrics_all_logs.csv' and stored in the result import csv import os -import re import sys from parsing_functions import * from pathlib import Path diff --git a/oonisubmitter/DESIGN.md b/oonisubmitter/DESIGN.md new file mode 100644 index 0000000..87ed4ee --- /dev/null +++ b/oonisubmitter/DESIGN.md @@ -0,0 +1,28 @@ +# Submit aggregate tunnel metrics to OONI + +This pipeline component is meant to run periodically. The entry +point is `submitter.py` and the other scripts are lower-level +components of the overall implementation: + +- `aggregator.py` contains the code to aggregate endpoint metrics; + +- `model.py` contains the data model; + +- `serializer.py` creates aggregate OONI measurements. + +The component reads on-disk state specifying the last measurement +that was submitted and the path to the CSV file or files. + +It loads the CSV entries and goes through them building aggregated +endpoint results, and produces OONI measurements. + +Finally, the measurements are submitted to the OONI collector. + +This component is designed to interface with the current ETL pipline, +which uses a single flat columnar CSV-based datastore. + +We probably want to start using file locking since we have more +then one component reading or writing the CSV file(s) at the same time. + +TODO(bassosimone): deduplication happens using pandas and we use the +CSVs, so we should probably deduplicate ourselves here. diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py new file mode 100644 index 0000000..059b524 --- /dev/null +++ b/oonisubmitter/aggregator.py @@ -0,0 +1,176 @@ +""" +Logic for aggregating field testing measurements into the aggregate +tunnel metrics OONI-compatible data format. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List +from . import model + + +@dataclass +class AggregatorConfig: + """ + Configuration for the measurement aggregator. + """ + + provider: str + upstream_collector: str + probe_asn: str + probe_cc: str + scope: str = "endpoint" # for now we only care about endpoint scope + + +@dataclass +class AggregateEndpointState: + """ + Flat representation of an endpoint's aggregated state. + + All fields needed for measurement generation should be here. + """ + + # Core identification + hostname: str + address: str + port: int + protocol: str + + # Classification + asn: str + cc: str + + # Context from config + provider: str + probe_asn: str + probe_cc: str + scope: str + + # Time window + window_start: datetime + window_end: datetime + + # Error tracking + # + # Statistics about successes and failures using + # empty string to represent success + errors: Dict[str, int] = field(default_factory=dict) + + # Ping statistics (collect raw values) + ping_min_values: List[float] = field(default_factory=list) + ping_avg_values: List[float] = field(default_factory=list) + ping_max_values: List[float] = field(default_factory=list) + ping_packets_loss_values: List[float] = field(default_factory=list) + + # NDT statistics (collect raw values) + download_throughput_values: List[float] = field(default_factory=list) + download_latency_values: List[float] = field(default_factory=list) + download_retransmission_values: List[float] = field(default_factory=list) + + upload_throughput_values: List[float] = field(default_factory=list) + upload_latency_values: List[float] = field(default_factory=list) + upload_retransmission_values: List[float] = field(default_factory=list) + + def update_error_counts(self, entry: model.FieldTestingCSVEntry) -> None: + """Update error counts based on a new entry""" + error_type = ( + "bootstrap.generic_error" if entry.is_tunnel_error_measurement() else "" + ) + self.errors[error_type] = self.errors.get(error_type, 0) + 1 + + def update_performance_metrics(self, entry: model.FieldTestingCSVEntry) -> None: + """Update performance metrics based on a new entry""" + if not entry.is_tunnel_error_measurement(): # only successful measurements + # Ping metrics + self.ping_min_values.append(entry.ping_roundtrip_min) + self.ping_avg_values.append(entry.ping_roundtrip_avg) + self.ping_max_values.append(entry.ping_roundtrip_max) + self.ping_packets_loss_values.append(entry.ping_packets_loss) + + # Download metrics + self.download_throughput_values.append(entry.throughput_download) + self.download_latency_values.append(entry.latency_download) + self.download_retransmission_values.append(entry.retransmission_download) + + # Upload metrics + self.upload_throughput_values.append(entry.throughput_upload) + self.upload_latency_values.append(entry.latency_upload) + self.upload_retransmission_values.append(entry.retransmission_upload) + + @classmethod + def from_csv_entry( + cls, + entry: model.FieldTestingCSVEntry, + config: AggregatorConfig, + window_start: datetime, + window_end: datetime, + ) -> "AggregateEndpointState": + return cls( + hostname=entry.server_fqdn, + address=entry.server_ip, + port=443, # TODO(bassosimone): check whether this is in the CSV + protocol=entry.protocol, + asn=entry.asn, + cc=entry.region, + provider=config.provider, + probe_asn=config.probe_asn, + probe_cc=config.probe_cc, + scope=config.scope, + window_start=window_start, + window_end=window_end, + ) + + +class EndpointAggregator: + """ + Maintains state for multiple endpoints. + """ + + def __init__( + self, config: AggregatorConfig, window_start: datetime, window_end: datetime + ): + self.config = config + self.window_start = window_start + self.window_end = window_end + self.endpoints: Dict[str, AggregateEndpointState] = {} + + def _make_key(self, hostname: str, address: str) -> str: + """Create unique key for an endpoint""" + # TODO(bassosimone): I wonder whether we need a more precise + # key here rather than the endpoint alone. This is probably not + # sufficient to address tcp/udp differences, for example. + return f"{hostname}:{address}" + + def _is_in_window(self, entry: model.FieldTestingCSVEntry) -> bool: + """Check if entry falls within our time window""" + return self.window_start <= entry.date < self.window_end + + def _is_tunnel_entry(self, entry: model.FieldTestingCSVEntry) -> bool: + """Check if entry is a tunnel measurement""" + return entry.is_tunnel_measurement() + + def update(self, entry: model.FieldTestingCSVEntry) -> None: + """ + Update aggregator state with a new measurement. + """ + # Exclude outside-window events as well as the + # events related to the baseline measurement for + # which we don't have spec support for now + if not self._is_in_window(entry): + return + if not self._is_tunnel_entry(entry): + return + + # Make sure we have an updated endpoint key here + key = self._make_key(entry.server_fqdn, entry.server_ip) + if key not in self.endpoints: + self.endpoints[key] = AggregateEndpointState.from_csv_entry( + entry, self.config, self.window_start, self.window_end + ) + + # Update the endpoint statistics + epnt = self.endpoints[key] + epnt.update_performance_metrics(entry) + epnt.update_error_counts(entry) diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py new file mode 100644 index 0000000..42bc5f1 --- /dev/null +++ b/oonisubmitter/model.py @@ -0,0 +1,284 @@ +""" +Data structures for the OONI aggregate tunnel metrics pipeline. + +This module defines the core data structures used throughout the +pipeline, including: + +1. the input CSV format; + +2. the output OONI measurement format. + +The structures closely follow: + +1. The CSV format used for storing network measurements; + +2. The OONI aggregate tunnel metrics specification. + +See: + +- https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md + +- https://0xacab.org/leap/solitech-compose-client/-/blob/main/images/obfsvpn-openvpn-client/start.sh +""" + +# TODO(bassosimone): When deploying in Docker environments, consider potential +# namespace conflicts with existing parsing helpers and other +# modules. This may require reorganization or renaming depending +# on the specific deployment context. + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Dict, List, Optional +import csv +import logging + + +@dataclass +class FieldTestingCSVEntry: + """ + Models a single field-testing entry read from the CSV datastore. + + The order of the fields in this dataclass it the same + of the fields within the CSV file. + """ + + date: datetime + asn: str + isp: str + est_city: str + user: str + region: str + server_fqdn: str + server_ip: str + mobile: bool + tunnel: str # 'baseline', 'tunnel', 'ERROR/baseline', 'ERROR/tunnel' + throughput_download: float + throughput_upload: float + latency_download: float + latency_upload: float + retransmission_download: float + retransmission_upload: float + ping_packets_loss: float + ping_roundtrip_min: float + ping_roundtrip_avg: float + ping_roundtrip_max: float + err_message: str + protocol: str + + def is_tunnel_measurement(self) -> bool: + """ + Return whether this is a tunnel measurement, which includes both + successful and failed tunnel measurements. + """ + return self.tunnel in ("tunnel", "ERROR/tunnel") + + def is_tunnel_error_measurement(self) -> bool: + """Return whether this is a failed tunnel measurement""" + return self.tunnel == "ERROR/tunnel" + + +class FieldTestingCSVFile: + """ + Models the content of a given CSV datastore and mediates access + to its entries through the `entries` property. + """ + + def __init__(self, filename: str): + """ + Initialize with CSV filename to load. + + Args: + filename: Path to the CSV file to read + """ + self.filename = filename + self._entries: Optional[List[FieldTestingCSVEntry]] = None + + def _parse_datetime(self, date_str: str) -> datetime: + """ + Parse datetime string from CSV into datetime object. + + Args: + date_str: Date string in format "%Y-%m-%d %H:%M:%S%z" + + Returns: + Parsed datetime object + """ + # TODO(bassosimone): I think this may be incorrect and we need + # to double check with the existing CSV files + return datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S%z") + + def _parse_bool(self, value: str) -> bool: + """ + Parse boolean string from CSV into bool. + + Args: + value: String representation of boolean ("true" or "false") + + Returns: + True if value.lower() == "true", False otherwise + """ + return value.lower() == "true" + + def load(self) -> List[FieldTestingCSVEntry]: + """ + Loads and returns entries from CSV file. + + Also caches the entries and makes them available + through the `entries` property. + """ + entries = [] + + with open(self.filename, "r") as f: + reader = csv.DictReader(f) + for row in reader: + try: + + measurement = FieldTestingCSVEntry( + date=self._parse_datetime(row["date"]), + asn=str(row["asn"]), + isp=str(row["isp"]), + est_city=str(row["est_city"]), + user=str(row["user"]), + region=str(row["region"]), + server_fqdn=str(row["server_fqdn"]), + server_ip=str(row["server_ip"]), + mobile=self._parse_bool(row["mobile"]), + tunnel=str(row["tunnel"]), + throughput_download=float(row["throughput_download"]), + throughput_upload=float(row["throughput_upload"]), + latency_download=float(row["latency_download"]), + latency_upload=float(row["latency_upload"]), + retransmission_download=float(row["retransmission_download"]), + retransmission_upload=float(row["retransmission_upload"]), + ping_packets_loss=float(row["ping_packets_loss"]), + ping_roundtrip_min=float(row["ping_roundtrip_min"]), + ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), + ping_roundtrip_max=float(row["ping_roundtrip_max"]), + err_message=str(row["err_message"]), + protocol=str(row["PT"]), + ) + entries.append(measurement) + + except (ValueError, KeyError) as exc: + logging.warning(f"cannot import row: {exc}") + continue + + self._entries = entries + return entries + + @property + def entries(self) -> List[FieldTestingCSVEntry]: + """Get cached entries or load if not yet loaded""" + if self._entries is None: + return self.load() + return self._entries + + +def datetime_to_compact_utc(dt: datetime) -> str: + """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" + return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +def datetime_to_ooni_format(dt: datetime) -> str: + """Convert datetime to OONI's format (YYYY-MM-DD hh:mm:ss)""" + return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + + +@dataclass +class AggregationTimeWindow: + """Time window for aggregating measurements""" + + from_time: datetime + to_time: datetime + + def as_dict(self) -> Dict: + """Convert to JSON-serializable dict""" + return { + "from": datetime_to_compact_utc(self.from_time), + "to": datetime_to_compact_utc(self.to_time), + } + + +@dataclass +class AggregateTunnelMetricsTestKeys: + """ + Models the test_keys portion of an OONI measurement as defined + in the aggregate tunnel metrics specification. + """ + + provider: str + scope: str # "endpoint", "endpoint_pool", or "global" + protocol: str + time_window: AggregationTimeWindow + + # Optional fields depending on scope + endpoint_hostname: Optional[str] + endpoint_address: Optional[str] + endpoint_port: Optional[int] + asn: Optional[str] # Format: ^AS[0-9]+$ + cc: Optional[str] # Format: ^[A-Z]{2}$ + bodies: List[Dict] # TODO (bassosimone): we can make this more specific later + + def as_dict(self) -> Dict: + """Convert to JSON-serializable dict""" + # Start with required fields + d = { + "provider": self.provider, + "scope": self.scope, + "protocol": self.protocol, + "time_window": self.time_window.as_dict(), + "bodies": self.bodies, + } + + # Add optional fields if they exist + for field in [ + "endpoint_hostname", + "endpoint_address", + "endpoint_port", + "asn", + "cc", + ]: + value = getattr(self, field) + if value is not None: + d[field] = value + + return d + + +@dataclass +class OONIMeasurement: + """ + Models the OONI measurement envelope. + """ + + annotations: Dict[str, str] + data_format_version: str + input: str # {protocol}://{provider}/?{query_string} + measurement_start_time: datetime + probe_asn: str # Format: ^AS[0-9]+$ + probe_cc: str # Format: ^[A-Z]{2}$ + test_keys: AggregateTunnelMetricsTestKeys + test_name: str + test_runtime: float + test_start_time: datetime + test_version: str + + def as_dict(self) -> Dict: + """Convert to JSON-serializable dict""" + # TODO(bassosimone): ensure we include the correct + # annotation about the collector we're using + return { + "annotations": self.annotations, + "data_format_version": self.data_format_version, + "input": self.input, + "measurement_start_time": datetime_to_ooni_format( + self.measurement_start_time + ), + "probe_asn": self.probe_asn, + "probe_cc": self.probe_cc, + "test_keys": self.test_keys.as_dict(), + "test_name": self.test_name, + "test_runtime": self.test_runtime, + "test_start_time": datetime_to_ooni_format(self.test_start_time), + "test_version": self.test_version, + } diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py new file mode 100644 index 0000000..f29a5d6 --- /dev/null +++ b/oonisubmitter/serializer.py @@ -0,0 +1,207 @@ +""" +Serializes aggregated endpoint state into OONI measurements format. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +""" + +from datetime import datetime +from typing import Dict, List +from statistics import quantiles +from . import model +from .aggregator import AggregateEndpointState, AggregatorConfig + + +class SerializationConfigError(Exception): + """Raised when serialization configuration does not allow us to proceed.""" + + +class OONISerializer: + """Converts aggregate endpoint state into OONI measurements""" + + def __init__(self, config: AggregatorConfig): + self.config = config + + @staticmethod + def _compute_percentiles(values: List[float]) -> Dict[str, float]: + """Compute the required percentiles for OONI format""" + + # TODO(bassosimone): we should not emit data if we have + # less than a configurable amount of measurements! + + if not values: + return {} + p25, p50, p75, p99 = quantiles(values, n=100, method="exclusive") + return { + "25p": round(p25, 1), + "50p": round(p50, 1), + "75p": round(p75, 1), + "99p": round(p99, 1), + } + + def _create_input_url(self, state: AggregateEndpointState) -> str: + """Create the measurement input URL""" + base = f"{state.protocol}://{state.provider}/" + # TODO(bassosimone): use an enum here? + # TODO(bassosimone): we currently only use endpoint scope + # so wondering whether to keep this if here + if state.scope == "endpoint": + params = { + "address": state.address, + "asn": state.asn, + "hostname": state.hostname, + "port": str(state.port), + } + # TODO(bassosimone): serialise with proper urlencoding + return base + "?" + "&".join(f"{k}={v}" for k, v in params.items() if v) + return base + + def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create the bodies section of test_keys""" + bodies = [] + + # TODO(bassosimone): we should round the sample sizes! + + # Creation phase errors (including successes) + total = sum(state.errors.values()) + if total > 0: + for error_type, count in state.errors.items(): + if not error_type: + # TODO(bassosimone): do we want to submit successes? + continue + bodies.append( + { + "phase": "creation", + "sample_size": total, + "type": "network-error", + "failure_ratio": round(count / total, 2), + "error": error_type, + } + ) + + # TODO(bassosimone): can we combine pings in a single message? + + # Ping measurements - separate bodies for min/avg/max + if state.ping_min_values: + bodies.append( + { + "phase": "tunnel_ping", + "sample_size": len(state.ping_min_values), + "type": "ping_min", + "latency_ms": self._compute_percentiles(state.ping_min_values), + } + ) + + if state.ping_avg_values: + bodies.append( + { + "phase": "tunnel_ping", + "sample_size": len(state.ping_avg_values), + "type": "ping_avg", + "latency_ms": self._compute_percentiles(state.ping_avg_values), + } + ) + + if state.ping_max_values: + bodies.append( + { + "phase": "tunnel_ping", + "sample_size": len(state.ping_max_values), + "type": "ping_max", + "latency_ms": self._compute_percentiles(state.ping_max_values), + } + ) + + if state.ping_packets_loss_values: + bodies.append( + { + "phase": "tunnel_ping", + "sample_size": len(state.ping_packets_loss_values), + "type": "ping_loss", + "loss_percent": self._compute_percentiles( + state.ping_packets_loss_values + ), + } + ) + + # NDT measurements (independent of ping) + if state.download_throughput_values: + bodies.append( + { + "phase": "tunnel_ndt_download", + "sample_size": len(state.download_throughput_values), + "type": "ndt_download", + "latency_ms": self._compute_percentiles( + state.download_latency_values + ), + "speed_mbits": self._compute_percentiles( + state.download_throughput_values + ), + "retransmission_percent": self._compute_percentiles( + state.download_retransmission_values + ), + } + ) + + if state.upload_throughput_values: + bodies.append( + { + "phase": "tunnel_ndt_upload", + "sample_size": len(state.upload_throughput_values), + "type": "ndt_upload", + "latency_ms": self._compute_percentiles( + state.upload_latency_values + ), + "speed_mbits": self._compute_percentiles( + state.upload_throughput_values + ), + "retransmission_percent": self._compute_percentiles( + state.upload_retransmission_values + ), + } + ) + + return bodies + + def serialize(self, state: AggregateEndpointState) -> model.OONIMeasurement: + """ + Convert endpoint state to OONI measurement format. + + Raises: + SerializationError: if the scope is not "endpoint" + """ + if state.scope != "endpoint": + raise SerializationConfigError( + f"cannot serialize measurement with scope '{state.scope}': " + "only 'endpoint' scope is currently supported" + ) + + measurement_time = datetime.utcnow() + + test_keys = model.AggregateTunnelMetricsTestKeys( + provider=state.provider, + scope=state.scope, + protocol=state.protocol, + time_window=model.AggregationTimeWindow( + from_time=state.window_start, to_time=state.window_end + ), + endpoint_hostname=state.hostname if state.scope == "endpoint" else None, + endpoint_address=state.address if state.scope == "endpoint" else None, + endpoint_port=state.port if state.scope == "endpoint" else None, + asn=state.asn, + cc=state.cc, + bodies=self._create_bodies(state), + ) + + return model.OONIMeasurement( + annotations={"upstream_collector": self.config.upstream_collector}, + data_format_version="0.2.0", + input=self._create_input_url(state), + measurement_start_time=measurement_time, + probe_asn=self.config.probe_asn, + probe_cc=self.config.probe_cc, + test_keys=test_keys, + test_name="aggregate_tunnel_metrics", + test_runtime=0.0, + test_start_time=measurement_time, + test_version="0.1.0", + ) diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py new file mode 100644 index 0000000..776b9c0 --- /dev/null +++ b/oonisubmitter/submitter.py @@ -0,0 +1,128 @@ +""" +Entry point for the aggregate-endpoint-metrics OONI measurement submitter. + +Depends on: + +1. AggregationConfig specifying what we are aggregating; + +2. file path of the state file, tracking the last field testing +measurement we considered for aggregation; + +3. FieldTestingCSVFile (or files) representing the field +testing measurements to aggregate; + +4. the window of time that we're aggregating. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +""" + +from datetime import datetime +from typing import Dict, List +import json +import logging + +from .aggregator import AggregatorConfig, EndpointAggregator +from .model import FieldTestingCSVEntry, FieldTestingCSVFile, OONIMeasurement +from .serializer import OONISerializer, SerializationConfigError + + +class OONISubmitter: + """ + Manages the submission of OONI measurements, including state tracking + and serialization. + """ + + # TODO(bassosimone): implement state-file locking + + def __init__(self, config: AggregatorConfig, state_file: str): + """ + Initializes the submitter with the given configuration and state file. + """ + self.config = config + self.state_file = state_file + self.state = self._load_state() + self.serializer = OONISerializer(config) + + def _load_state(self) -> Dict: + """ + Load submission state from disk. + """ + try: + with open(self.state_file) as f: + return json.load(f) + except FileNotFoundError: + return {"last_submitted": None} + + def _save_state(self): + """Save current state to disk""" + with open(self.state_file, "w") as f: + json.dump(self.state, f) + + def _should_process(self, entry: FieldTestingCSVEntry) -> bool: + """ + Determine whether we should process this entry based on + last submission state. + + Args: + entry: The CSV entry to check + + Returns: + True if entry should be processed, False otherwise + """ + # If we've never submitted anything, just process everything + if self.state["last_submitted"] is None: + return True + + # Parse last submission time from state + last_submitted = datetime.fromisoformat(self.state["last_submitted"]) + + # Only process entries newer than our last submission + return entry.date > last_submitted + + # TODO(bassosimone): the current method is not super + # ergonomic for the following reasons: + # + # 1. we should probably load the CSV file(s) ourselves + # + # 2. we should auto-compute the time window + # + # For now, I am leaving it all as-is since we need + # this function for integration testing and we will + # need to iterate on the top-level API anyway: we + # still need to specify how to properly hook in the + # rest of the ETL pipeline. + + def process_csv_file( + self, + csv_file: FieldTestingCSVFile, + window_start: datetime, + window_end: datetime, + ) -> List[OONIMeasurement]: + """ + Process CSV file and return JSON strings for new measurements. + + Return the raw measurements to submit. + + Note: DOES NOT update the state, which should be updated once the + measurements have been successfully submitted. + """ + aggregator = EndpointAggregator(self.config, window_start, window_end) + + # Only process the new field-testing entries + for entry in csv_file.entries: + if self._should_process(entry): + aggregator.update(entry) + + # Generate the measurements + measurements: List[OONIMeasurement] = [] + for endpoint in aggregator.endpoints.values(): + try: + measurement = self.serializer.serialize(endpoint) + measurements.append(measurement) + except SerializationConfigError: + logging.warning( + f"skipping endpoint {endpoint.hostname}: serialization config error" + ) + continue + + return measurements -- GitLab From 792b7c33c1ed9f4be1cf1385b8d67f8214c37c39 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Fri, 6 Dec 2024 15:50:49 +0100 Subject: [PATCH 02/75] fix(oonisubmitter): use correct endpoint addr/port This diff fixes the oonisubmitter to extract and use the correct endpoint address and port. We were previously using incorrectly the NDT server FQN and address, leaving the port as hardcoded to 443. I noticed this issue while working on writing test fixtures. While there, notice that the ping and ndt measurements we're generating are not related to the specific target, so also fix this issue. While there, set the stage for rounding the sample size, while not doing this yet. We will do this in a subsequent commit. While there, document in DESIGN.md the need to update the CSV generation. I don't want to do this right now, since I'm focusing on the oonisubmitter unit, but we ought to do this. Lacking the existence of a proper issue for now, I choose to put this information into a PR-visible place. (The DESIGN.md file is a placeholder anyway and I don't think we should actually merge it - it just feels a place where to document my though process but the actual content should be used to update the original spec or creating follow-up issues when merging this pull request.) --- oonisubmitter/DESIGN.md | 47 +++++++++ oonisubmitter/aggregator.py | 113 +++++++++++++-------- oonisubmitter/model.py | 19 ++++ oonisubmitter/serializer.py | 196 ++++++++++++++++++++---------------- 4 files changed, 246 insertions(+), 129 deletions(-) diff --git a/oonisubmitter/DESIGN.md b/oonisubmitter/DESIGN.md index 87ed4ee..abf7c19 100644 --- a/oonisubmitter/DESIGN.md +++ b/oonisubmitter/DESIGN.md @@ -26,3 +26,50 @@ then one component reading or writing the CSV file(s) at the same time. TODO(bassosimone): deduplication happens using pandas and we use the CSVs, so we should probably deduplicate ourselves here. + +## 2024-06-02: Additional issues + +While working to write unit and integration tests (and specifically while +writing fixtures), I realised that the current CSV format is not sufficient +to produce aggregate tunnel metrics. My proposal is to fix this by adding +additional fields to the CSV data format. (I don't think we should shuffle +the fields to retain backwards compatibility.) + +We may also want to rename some fields for clarity, however this is a +secondary order concern, for in the CSV domain what actually matters is +the index of the fields. Additionally, I don't know if the name of the +fielfs is used elsewhere in the pipeline (need to double check). + +### CSV Field Structure + +#### Current Fields + +| Field | Purpose | Source | +|---------------------|-----------------------------|--------------------| +| date | Measurement timestamp | Log timestamp | +| asn | Client ASN | whois lookup | +| isp | Client ISP | wtfismyip.com API | +| est_city | Client estimated city | wtfismyip.com API | +| user | Client user identifier | CLIENT_USER env | +| region | Client region | REGION env | +| server_fqdn | NDT test server FQDN | NDT test results | +| server_ip | NDT test server IP | NDT test results | +| mobile | Client network type | MOBILE_NETWORK env | +| tunnel | Measurement classification | Parser logic | +| throughput_* | NDT performance metrics | NDT test results | +| latency_* | NDT latency metrics | NDT test results | +| retransmission_* | NDT retransmission stats | NDT test results | +| ping_* | Ping test metrics | ping test results | +| err_message | Error details | Parser detection | +| PT | Protocol type | obfsvpn logs | + +#### Required Additional Fields + +| Field | Purpose | Source | +|----------------------|----------------------------|---------------------| +| endpoint_hostname | Tunnel endpoint FQDN | OBFS4_SERVER_HOST* | +| endpoint_address | Tunnel endpoint IP | (needs resolution) | +| endpoint_port | Tunnel endpoint port | OBFS4_PORT | +| endpoint_asn | Tunnel endpoint ASN | (needs resolution) | +| endpoint_cc | Tunnel endpoint CC | (needs resolution) | +| ping_target_address | Ping test target IP | ping command target | diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index 059b524..8c7a5d5 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -58,20 +58,31 @@ class AggregateEndpointState: # empty string to represent success errors: Dict[str, int] = field(default_factory=dict) - # Ping statistics (collect raw values) - ping_min_values: List[float] = field(default_factory=list) - ping_avg_values: List[float] = field(default_factory=list) - ping_max_values: List[float] = field(default_factory=list) - ping_packets_loss_values: List[float] = field(default_factory=list) - - # NDT statistics (collect raw values) - download_throughput_values: List[float] = field(default_factory=list) - download_latency_values: List[float] = field(default_factory=list) - download_retransmission_values: List[float] = field(default_factory=list) - - upload_throughput_values: List[float] = field(default_factory=list) - upload_latency_values: List[float] = field(default_factory=list) - upload_retransmission_values: List[float] = field(default_factory=list) + # Ping statistics organised by target + # + # { + # "8.8.8.8": { + # "min": [...], + # "avg": [...], + # "max": [...], + # "loss": [...] + # } + # } + ping_measurements: Dict[str, Dict[str, List[float]]] = field(default_factory=dict) + + # NDT statistics organised by target + # + # { + # "server.fqdn:ip": { + # "download_throughput": [...], + # "download_latency": [...], + # "download_retransmission": [...], + # "upload_throughput": [...], + # "upload_latency": [...], + # "upload_retransmission": [...], + # } + # } + ndt_measurements: Dict[str, Dict[str, List[float]]] = field(default_factory=dict) def update_error_counts(self, entry: model.FieldTestingCSVEntry) -> None: """Update error counts based on a new entry""" @@ -83,21 +94,44 @@ class AggregateEndpointState: def update_performance_metrics(self, entry: model.FieldTestingCSVEntry) -> None: """Update performance metrics based on a new entry""" if not entry.is_tunnel_error_measurement(): # only successful measurements - # Ping metrics - self.ping_min_values.append(entry.ping_roundtrip_min) - self.ping_avg_values.append(entry.ping_roundtrip_avg) - self.ping_max_values.append(entry.ping_roundtrip_max) - self.ping_packets_loss_values.append(entry.ping_packets_loss) - - # Download metrics - self.download_throughput_values.append(entry.throughput_download) - self.download_latency_values.append(entry.latency_download) - self.download_retransmission_values.append(entry.retransmission_download) - - # Upload metrics - self.upload_throughput_values.append(entry.throughput_upload) - self.upload_latency_values.append(entry.latency_upload) - self.upload_retransmission_values.append(entry.retransmission_upload) + self._update_ping(entry) + self._update_ndt(entry) + + def _update_ping(self, entry: model.FieldTestingCSVEntry) -> None: + """Unconditionally update the ping metrics.""" + ping_target = entry.ping_target_address + if ping_target not in self.ping_measurements: + self.ping_measurements[ping_target] = { + "min": [], + "avg": [], + "max": [], + "loss": [], + } + metrics = self.ping_measurements[ping_target] + metrics["min"].append(entry.ping_roundtrip_min) + metrics["avg"].append(entry.ping_roundtrip_avg) + metrics["max"].append(entry.ping_roundtrip_max) + metrics["loss"].append(entry.ping_packets_loss) + + def _update_ndt(self, entry: model.FieldTestingCSVEntry) -> None: + """Unconditionally update the NDT metrics.""" + ndt_target = f"{entry.server_fqdn}:{entry.server_ip}" + if ndt_target not in self.ndt_measurements: + self.ndt_measurements[ndt_target] = { + "download_throughput": [], + "download_latency": [], + "download_retransmission": [], + "upload_throughput": [], + "upload_latency": [], + "upload_retransmission": [], + } + metrics = self.ndt_measurements[ndt_target] + metrics["download_throughput"].append(entry.throughput_download) + metrics["download_latency"].append(entry.latency_download) + metrics["download_retransmission"].append(entry.retransmission_download) + metrics["upload_throughput"].append(entry.throughput_upload) + metrics["upload_latency"].append(entry.latency_upload) + metrics["upload_retransmission"].append(entry.retransmission_upload) @classmethod def from_csv_entry( @@ -108,12 +142,12 @@ class AggregateEndpointState: window_end: datetime, ) -> "AggregateEndpointState": return cls( - hostname=entry.server_fqdn, - address=entry.server_ip, - port=443, # TODO(bassosimone): check whether this is in the CSV + hostname=entry.endpoint_hostname, + address=entry.endpoint_address, + port=entry.endpoint_port, protocol=entry.protocol, - asn=entry.asn, - cc=entry.region, + asn=entry.endpoint_asn, + cc=entry.endpoint_cc, provider=config.provider, probe_asn=config.probe_asn, probe_cc=config.probe_cc, @@ -136,12 +170,9 @@ class EndpointAggregator: self.window_end = window_end self.endpoints: Dict[str, AggregateEndpointState] = {} - def _make_key(self, hostname: str, address: str) -> str: + def _make_key(self, entry: model.FieldTestingCSVEntry) -> str: """Create unique key for an endpoint""" - # TODO(bassosimone): I wonder whether we need a more precise - # key here rather than the endpoint alone. This is probably not - # sufficient to address tcp/udp differences, for example. - return f"{hostname}:{address}" + return f"{entry.endpoint_hostname}|{entry.endpoint_address}|{entry.endpoint_port}|{entry.protocol}" def _is_in_window(self, entry: model.FieldTestingCSVEntry) -> bool: """Check if entry falls within our time window""" @@ -163,8 +194,8 @@ class EndpointAggregator: if not self._is_tunnel_entry(entry): return - # Make sure we have an updated endpoint key here - key = self._make_key(entry.server_fqdn, entry.server_ip) + # Make sure we are tracking this endpoint + key = self._make_key(entry) if key not in self.endpoints: self.endpoints[key] = AggregateEndpointState.from_csv_entry( entry, self.config, self.window_start, self.window_end diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 42bc5f1..253f8c0 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -42,6 +42,8 @@ class FieldTestingCSVEntry: of the fields within the CSV file. """ + # Fields originally present in the CSV file + # format as of 2024-12-06 date: datetime asn: str isp: str @@ -65,6 +67,17 @@ class FieldTestingCSVEntry: err_message: str protocol: str + # Fields added on 2024-12-06 to allow for exporting + # endpoint-level aggregate tunnel metrics. + # + # TODO(XXX): update the CSV file spec and generation. + endpoint_hostname: str + endpoint_address: str + endpoint_port: int + endpoint_asn: str + endpoint_cc: str + ping_target_address: str + def is_tunnel_measurement(self) -> bool: """ Return whether this is a tunnel measurement, which includes both @@ -156,6 +169,12 @@ class FieldTestingCSVFile: ping_roundtrip_max=float(row["ping_roundtrip_max"]), err_message=str(row["err_message"]), protocol=str(row["PT"]), + endpoint_hostname=str(row["endpoint_hostname"]), + endpoint_address=str(row["endpoint_address"]), + endpoint_port=int(row["endpoint_port"]), + endpoint_asn=str(row["endpoint_asn"]), + endpoint_cc=str(row["endpoint_cc"]), + ping_target_address=str(row["ping_target_address"]), ) entries.append(measurement) diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index f29a5d6..d75b3c2 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -55,111 +55,131 @@ class OONISerializer: return base + "?" + "&".join(f"{k}={v}" for k, v in params.items() if v) return base - def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: - """Create the bodies section of test_keys""" + def _round_sample_size(self, sample_size: int) -> int: + """Round the sample size according to the aggregate tunnel metrics spec""" + # TODO(bassosimone): implement rounding of the sample size + # according to what has been written inside the spec + return sample_size + + def _create_error_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create error bodies if there are any errors""" bodies = [] - - # TODO(bassosimone): we should round the sample sizes! - - # Creation phase errors (including successes) total = sum(state.errors.values()) if total > 0: for error_type, count in state.errors.items(): - if not error_type: - # TODO(bassosimone): do we want to submit successes? + if not error_type: # Skip success counts continue bodies.append( { "phase": "creation", - "sample_size": total, + "sample_size": self._round_sample_size(total), "type": "network-error", "failure_ratio": round(count / total, 2), "error": error_type, } ) + return bodies - # TODO(bassosimone): can we combine pings in a single message? - - # Ping measurements - separate bodies for min/avg/max - if state.ping_min_values: - bodies.append( - { - "phase": "tunnel_ping", - "sample_size": len(state.ping_min_values), - "type": "ping_min", - "latency_ms": self._compute_percentiles(state.ping_min_values), - } - ) - - if state.ping_avg_values: - bodies.append( - { - "phase": "tunnel_ping", - "sample_size": len(state.ping_avg_values), - "type": "ping_avg", - "latency_ms": self._compute_percentiles(state.ping_avg_values), - } - ) - - if state.ping_max_values: - bodies.append( - { - "phase": "tunnel_ping", - "sample_size": len(state.ping_max_values), - "type": "ping_max", - "latency_ms": self._compute_percentiles(state.ping_max_values), - } - ) - - if state.ping_packets_loss_values: - bodies.append( - { - "phase": "tunnel_ping", - "sample_size": len(state.ping_packets_loss_values), - "type": "ping_loss", - "loss_percent": self._compute_percentiles( - state.ping_packets_loss_values - ), - } - ) + def _create_ping_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create bodies for ping measurements""" + bodies = [] + for target_address, measurements in state.ping_measurements.items(): + # Min/Avg/Max latency bodies + for metric_type in ["min", "avg", "max"]: + if measurements[metric_type]: # Only if we have measurements + bodies.append( + { + "phase": "tunnel_ping", + "sample_size": self._round_sample_size( + len(measurements[metric_type]) + ), + "type": f"ping_{metric_type}", + "target_address": target_address, + "latency_ms": self._compute_percentiles( + measurements[metric_type] + ), + } + ) + + # Packet loss body + if measurements["loss"]: + bodies.append( + { + "phase": "tunnel_ping", + "sample_size": self._round_sample_size( + len(measurements["loss"]) + ), + "type": "ping_loss", + "target_address": target_address, + "loss_percent": self._compute_percentiles(measurements["loss"]), + } + ) + return bodies - # NDT measurements (independent of ping) - if state.download_throughput_values: - bodies.append( - { - "phase": "tunnel_ndt_download", - "sample_size": len(state.download_throughput_values), - "type": "ndt_download", - "latency_ms": self._compute_percentiles( - state.download_latency_values - ), - "speed_mbits": self._compute_percentiles( - state.download_throughput_values - ), - "retransmission_percent": self._compute_percentiles( - state.download_retransmission_values - ), - } - ) + def _create_ndt_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create bodies for NDT measurements""" + bodies = [] + for target_id, measurements in state.ndt_measurements.items(): + # TODO(bassosimone): I am not convinced we should use this algorithm + # here to represent an NDT server and maybe we could go for sth + # that is more direct and the `:` is annoying because of IPv6 anyway + hostname, address = target_id.split(":", 1) + + # Download measurements + if measurements["download_throughput"]: + bodies.append( + { + "phase": "tunnel_ndt_download", + "sample_size": self._round_sample_size( + len(measurements["download_throughput"]) + ), + "type": "ndt_download", + "target_hostname": hostname, + "target_address": address, + "target_port": 443, # TODO: Get actual port + "latency_ms": self._compute_percentiles( + measurements["download_latency"] + ), + "speed_mbits": self._compute_percentiles( + measurements["download_throughput"] + ), + "retransmission_percent": self._compute_percentiles( + measurements["download_retransmission"] + ), + } + ) - if state.upload_throughput_values: - bodies.append( - { - "phase": "tunnel_ndt_upload", - "sample_size": len(state.upload_throughput_values), - "type": "ndt_upload", - "latency_ms": self._compute_percentiles( - state.upload_latency_values - ), - "speed_mbits": self._compute_percentiles( - state.upload_throughput_values - ), - "retransmission_percent": self._compute_percentiles( - state.upload_retransmission_values - ), - } - ) + # Upload measurements + if measurements["upload_throughput"]: + bodies.append( + { + "phase": "tunnel_ndt_upload", + "sample_size": self._round_sample_size( + len(measurements["upload_throughput"]) + ), + "type": "ndt_upload", + "target_hostname": hostname, + "target_address": address, + "target_port": 443, # TODO(bassosimone): get actual port + "latency_ms": self._compute_percentiles( + measurements["upload_latency"] + ), + "speed_mbits": self._compute_percentiles( + measurements["upload_throughput"] + ), + "retransmission_percent": self._compute_percentiles( + measurements["upload_retransmission"] + ), + } + ) + return bodies + def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create the bodies section of test_keys""" + bodies = [] + bodies.extend(self._create_error_bodies(state)) + bodies.extend(self._create_ping_bodies(state)) + bodies.extend(self._create_ndt_bodies(state)) return bodies def serialize(self, state: AggregateEndpointState) -> model.OONIMeasurement: -- GitLab From 906bd56966d6e07693f48f19e98610f3081a0c59 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Fri, 6 Dec 2024 16:42:53 +0100 Subject: [PATCH 03/75] feat(oonisubmitter): test the aggregator This diff adds (for now, smoke) tests for the aggregator. The general idea is to start testing we're aggregating correctly before moving to test we're producing the correct measurement. --- .gitignore | 3 +- oonisubmitter/aggregator.py | 50 ++++++++- oonisubmitter/aggregator_test.py | 56 ++++++++++ oonisubmitter/model.py | 29 +++-- oonisubmitter/serializer.py | 7 +- oonisubmitter/submitter.py | 8 +- oonisubmitter/testdata/expected_state.json | 124 +++++++++++++++++++++ oonisubmitter/testdata/sample.csv | 4 + 8 files changed, 265 insertions(+), 16 deletions(-) create mode 100755 oonisubmitter/aggregator_test.py create mode 100644 oonisubmitter/testdata/expected_state.json create mode 100644 oonisubmitter/testdata/sample.csv diff --git a/.gitignore b/.gitignore index 1828c3b..be38085 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ +__pycache__ logs/* -results/* \ No newline at end of file +results/* diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index 8c7a5d5..baaee6c 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -8,7 +8,10 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag from dataclasses import dataclass, field from datetime import datetime from typing import Dict, List -from . import model + +# TODO(bassosimone): I think we should consider creating +# proper python modules eventually :thinking: +import model @dataclass @@ -156,6 +159,31 @@ class AggregateEndpointState: window_end=window_end, ) + def to_dict(self) -> Dict: + """Convert state to a JSON-serializable dictionary""" + return { + # Core identification + "hostname": self.hostname, + "address": self.address, + "port": self.port, + "protocol": self.protocol, + # Classification + "asn": self.asn, + "cc": self.cc, + # Context + "provider": self.provider, + "probe_asn": self.probe_asn, + "probe_cc": self.probe_cc, + "scope": self.scope, + # Time window + "window_start": model.datetime_to_compact_utc(self.window_start), + "window_end": model.datetime_to_compact_utc(self.window_end), + # Measurements + "errors": self.errors, + "ping_measurements": self.ping_measurements, + "ndt_measurements": self.ndt_measurements, + } + class EndpointAggregator: """ @@ -205,3 +233,23 @@ class EndpointAggregator: epnt = self.endpoints[key] epnt.update_performance_metrics(entry) epnt.update_error_counts(entry) + + def to_dict(self) -> Dict: + """Convert aggregator state to a JSON-serializable dictionary""" + return { + # Config + "config": { + "provider": self.config.provider, + "upstream_collector": self.config.upstream_collector, + "probe_asn": self.config.probe_asn, + "probe_cc": self.config.probe_cc, + "scope": self.config.scope, + }, + # Time window + "window_start": model.datetime_to_compact_utc(self.window_start), + "window_end": model.datetime_to_compact_utc(self.window_end), + # Endpoints state + "endpoints": { + key: state.to_dict() for key, state in self.endpoints.items() + }, + } diff --git a/oonisubmitter/aggregator_test.py b/oonisubmitter/aggregator_test.py new file mode 100755 index 0000000..d4a3498 --- /dev/null +++ b/oonisubmitter/aggregator_test.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 + +"""Basic tests for the endpoint aggregator""" + +import json +import unittest +from datetime import datetime, timezone +from difflib import unified_diff + +import aggregator +import model + + +class TestAggregator(unittest.TestCase): + """Test the endpoint aggregator""" + + def test_aggregator(self): + """Verify that the aggregator correctly processes sample data""" + # Load test data + csv_file = model.FieldTestingCSVFile("testdata/sample.csv") + + # Configure aggregator + config = aggregator.AggregatorConfig( + provider="test.provider", + upstream_collector="test-collector", + probe_asn="AS12345", + probe_cc="XX", + ) + window_start = datetime(2024, 1, 1, 12, tzinfo=timezone.utc) + window_end = datetime(2024, 1, 1, 14, tzinfo=timezone.utc) + + # Process entries + aggr = aggregator.EndpointAggregator(config, window_start, window_end) + for entry in csv_file.entries: + aggr.update(entry) + + # Compare with expected state + actual_lines = json.dumps(aggr.to_dict(), indent=4).splitlines() + with open("testdata/expected_state.json") as f: + expected_lines = f.read().splitlines() + + if actual_lines != expected_lines: + diff = list( + unified_diff( + expected_lines, + actual_lines, + fromfile="expected", + tofile="actual", + lineterm="", + ) + ) + self.fail("\n".join(diff)) + + +if __name__ == "__main__": + unittest.main() diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 253f8c0..382dde3 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -108,17 +108,28 @@ class FieldTestingCSVFile: def _parse_datetime(self, date_str: str) -> datetime: """ - Parse datetime string from CSV into datetime object. + Parse ctime formatted date from CSV into datetime object. - Args: - date_str: Date string in format "%Y-%m-%d %H:%M:%S%z" - - Returns: - Parsed datetime object + Example format: "Fri Dec 6 15:27:16 UTC 2024" """ - # TODO(bassosimone): I think this may be incorrect and we need - # to double check with the existing CSV files - return datetime.strptime(date_str, "%Y-%m-%d %H:%M:%S%z") + # strptime directives: + # %a - Weekday name (Mon) + # %b - Month name (Nov) + # %d - Day of month (18) + # %H:%M:%S - Time (17:18:39) + # %Z - Timezone name (UTC) + # %Y - Year (2024) + dt = datetime.strptime(date_str, "%a %b %d %H:%M:%S %Z %Y") + + # For now, since we expect UTC, let's be strict + # + # TODO(bassosimone): do we need to care about non-UTC? + if "UTC" not in date_str: + raise ValueError( + f"expected UTC timezone in date string, got: {date_str}" + ) + + return dt.replace(tzinfo=timezone.utc) def _parse_bool(self, value: str) -> bool: """ diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index d75b3c2..10ec694 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -7,8 +7,11 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag from datetime import datetime from typing import Dict, List from statistics import quantiles -from . import model -from .aggregator import AggregateEndpointState, AggregatorConfig + +# TODO(bassosimone): I think we should consider creating +# proper python modules eventually :thinking: +import model +from aggregator import AggregateEndpointState, AggregatorConfig class SerializationConfigError(Exception): diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 776b9c0..51338a3 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -21,9 +21,11 @@ from typing import Dict, List import json import logging -from .aggregator import AggregatorConfig, EndpointAggregator -from .model import FieldTestingCSVEntry, FieldTestingCSVFile, OONIMeasurement -from .serializer import OONISerializer, SerializationConfigError +# TODO(bassosimone): I think we should consider creating +# proper python modules eventually :thinking: +from aggregator import AggregatorConfig, EndpointAggregator +from model import FieldTestingCSVEntry, FieldTestingCSVFile, OONIMeasurement +from serializer import OONISerializer, SerializationConfigError class OONISubmitter: diff --git a/oonisubmitter/testdata/expected_state.json b/oonisubmitter/testdata/expected_state.json new file mode 100644 index 0000000..30bebf7 --- /dev/null +++ b/oonisubmitter/testdata/expected_state.json @@ -0,0 +1,124 @@ +{ + "config": { + "provider": "test.provider", + "upstream_collector": "test-collector", + "probe_asn": "AS12345", + "probe_cc": "XX", + "scope": "endpoint" + }, + "window_start": "20240101T120000Z", + "window_end": "20240101T140000Z", + "endpoints": { + "bridge1.test.org|192.168.0.1|443|obfs4": { + "hostname": "bridge1.test.org", + "address": "192.168.0.1", + "port": 443, + "protocol": "obfs4", + "asn": "AS1234", + "cc": "IT", + "provider": "test.provider", + "probe_asn": "AS12345", + "probe_cc": "XX", + "scope": "endpoint", + "window_start": "20240101T120000Z", + "window_end": "20240101T140000Z", + "errors": { + "": 1, + "bootstrap.generic_error": 1 + }, + "ping_measurements": { + "8.8.8.8": { + "min": [ + 10.0 + ], + "avg": [ + 12.0 + ], + "max": [ + 15.0 + ], + "loss": [ + 0.0 + ] + } + }, + "ndt_measurements": { + "ndt1.test.org:10.0.0.1": { + "download_throughput": [ + 100.0 + ], + "download_latency": [ + 20.0 + ], + "download_retransmission": [ + 0.01 + ], + "upload_throughput": [ + 50.0 + ], + "upload_latency": [ + 25.0 + ], + "upload_retransmission": [ + 0.02 + ] + } + } + }, + "bridge2.test.org|192.168.0.2|443|obfs4+kcp": { + "hostname": "bridge2.test.org", + "address": "192.168.0.2", + "port": 443, + "protocol": "obfs4+kcp", + "asn": "AS1234", + "cc": "IT", + "provider": "test.provider", + "probe_asn": "AS12345", + "probe_cc": "XX", + "scope": "endpoint", + "window_start": "20240101T120000Z", + "window_end": "20240101T140000Z", + "errors": { + "": 1 + }, + "ping_measurements": { + "8.8.4.4": { + "min": [ + 11.0 + ], + "avg": [ + 13.0 + ], + "max": [ + 16.0 + ], + "loss": [ + 0.0 + ] + } + }, + "ndt_measurements": { + "ndt1.test.org:10.0.0.1": { + "download_throughput": [ + 120.0 + ], + "download_latency": [ + 22.0 + ], + "download_retransmission": [ + 0.015 + ], + "upload_throughput": [ + 60.0 + ], + "upload_latency": [ + 27.0 + ], + "upload_retransmission": [ + 0.025 + ] + } + } + } + } +} diff --git a/oonisubmitter/testdata/sample.csv b/oonisubmitter/testdata/sample.csv new file mode 100644 index 0000000..9ca756e --- /dev/null +++ b/oonisubmitter/testdata/sample.csv @@ -0,0 +1,4 @@ +date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,endpoint_hostname,endpoint_address,endpoint_port,endpoint_asn,endpoint_cc,ping_target_address +Mon Jan 01 12:00:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt1.test.org,10.0.0.1,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,bridge1.test.org,192.168.0.1,443,AS1234,IT,8.8.8.8 +Mon Jan 01 12:30:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt2.test.org,10.0.0.2,false,ERROR/tunnel,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,connection_failed,obfs4,bridge1.test.org,192.168.0.1,443,AS1234,IT,8.8.8.8 +Mon Jan 01 13:00:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt1.test.org,10.0.0.1,false,tunnel,120.0,60.0,22.0,27.0,0.015,0.025,0.0,11.0,13.0,16.0,,obfs4+kcp,bridge2.test.org,192.168.0.2,443,AS1234,IT,8.8.4.4 -- GitLab From 832efa5b264ec07be1d586c0927393d05f6051e7 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Fri, 6 Dec 2024 16:55:43 +0100 Subject: [PATCH 04/75] fix(oonisubmitter): align with actual CSV format This diff aligns the model and the fixtures with the actual CSV data format after comparison with the actual data. --- oonisubmitter/DESIGN.md | 8 +++++++- oonisubmitter/model.py | 10 +++++----- oonisubmitter/testdata/sample.csv | 8 ++++---- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/oonisubmitter/DESIGN.md b/oonisubmitter/DESIGN.md index abf7c19..268a201 100644 --- a/oonisubmitter/DESIGN.md +++ b/oonisubmitter/DESIGN.md @@ -68,8 +68,14 @@ fielfs is used elsewhere in the pipeline (need to double check). | Field | Purpose | Source | |----------------------|----------------------------|---------------------| | endpoint_hostname | Tunnel endpoint FQDN | OBFS4_SERVER_HOST* | -| endpoint_address | Tunnel endpoint IP | (needs resolution) | +| endpoint_address | Tunnel endpoint IP | (needs resolution?) | | endpoint_port | Tunnel endpoint port | OBFS4_PORT | | endpoint_asn | Tunnel endpoint ASN | (needs resolution) | | endpoint_cc | Tunnel endpoint CC | (needs resolution) | | ping_target_address | Ping test target IP | ping command target | + +#### CSV Field Mapping Notes + +- The `protocol` field is mapped from the CSV column named `PT` +- Empty error messages are represented as "" (empty string) +- The `filename` field is parsed but not used for metrics aggregation diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 382dde3..a1cb6a1 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -44,6 +44,7 @@ class FieldTestingCSVEntry: # Fields originally present in the CSV file # format as of 2024-12-06 + filename: str date: datetime asn: str isp: str @@ -125,9 +126,7 @@ class FieldTestingCSVFile: # # TODO(bassosimone): do we need to care about non-UTC? if "UTC" not in date_str: - raise ValueError( - f"expected UTC timezone in date string, got: {date_str}" - ) + raise ValueError(f"expected UTC timezone in date string, got: {date_str}") return dt.replace(tzinfo=timezone.utc) @@ -158,6 +157,7 @@ class FieldTestingCSVFile: try: measurement = FieldTestingCSVEntry( + filename=str(row["filename"]), date=self._parse_datetime(row["date"]), asn=str(row["asn"]), isp=str(row["isp"]), @@ -178,8 +178,8 @@ class FieldTestingCSVFile: ping_roundtrip_min=float(row["ping_roundtrip_min"]), ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), ping_roundtrip_max=float(row["ping_roundtrip_max"]), - err_message=str(row["err_message"]), - protocol=str(row["PT"]), + err_message=str(row["err_message"]).strip(), + protocol=str(row["PT"]), # rename from "PT" to "protocol" endpoint_hostname=str(row["endpoint_hostname"]), endpoint_address=str(row["endpoint_address"]), endpoint_port=int(row["endpoint_port"]), diff --git a/oonisubmitter/testdata/sample.csv b/oonisubmitter/testdata/sample.csv index 9ca756e..9818421 100644 --- a/oonisubmitter/testdata/sample.csv +++ b/oonisubmitter/testdata/sample.csv @@ -1,4 +1,4 @@ -date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,endpoint_hostname,endpoint_address,endpoint_port,endpoint_asn,endpoint_cc,ping_target_address -Mon Jan 01 12:00:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt1.test.org,10.0.0.1,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,bridge1.test.org,192.168.0.1,443,AS1234,IT,8.8.8.8 -Mon Jan 01 12:30:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt2.test.org,10.0.0.2,false,ERROR/tunnel,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,connection_failed,obfs4,bridge1.test.org,192.168.0.1,443,AS1234,IT,8.8.8.8 -Mon Jan 01 13:00:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt1.test.org,10.0.0.1,false,tunnel,120.0,60.0,22.0,27.0,0.015,0.025,0.0,11.0,13.0,16.0,,obfs4+kcp,bridge2.test.org,192.168.0.2,443,AS1234,IT,8.8.4.4 +filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,endpoint_hostname,endpoint_address,endpoint_port,endpoint_asn,endpoint_cc,ping_target_address +1.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt1.test.org,10.0.0.1,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,bridge1.test.org,192.168.0.1,443,AS1234,IT,8.8.8.8 +1.csv,Mon Jan 01 12:30:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt2.test.org,10.0.0.2,false,ERROR/tunnel,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,connection_failed,obfs4,bridge1.test.org,192.168.0.1,443,AS1234,IT,8.8.8.8 +1.csv,Mon Jan 01 13:00:00 UTC 2024,AS12345,Test ISP,Test City,user1,TestRegion,ndt1.test.org,10.0.0.1,false,tunnel,120.0,60.0,22.0,27.0,0.015,0.025,0.0,11.0,13.0,16.0,,obfs4+kcp,bridge2.test.org,192.168.0.2,443,AS1234,IT,8.8.4.4 -- GitLab From 2677c91dcbcd68e3104d9394b8da305f0eeafe28 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Fri, 6 Dec 2024 17:56:16 +0100 Subject: [PATCH 05/75] doc: revamp design/explanation for this branch I tried to write down clearly what is my understanding, what are my assumptions, and what changes are needed now or I think may be needed in the future, to accommodate for this work. --- oonisubmitter/DESIGN.md | 197 +++++++++++++++++++++++++++++++++------- 1 file changed, 165 insertions(+), 32 deletions(-) diff --git a/oonisubmitter/DESIGN.md b/oonisubmitter/DESIGN.md index 268a201..cb3ad6f 100644 --- a/oonisubmitter/DESIGN.md +++ b/oonisubmitter/DESIGN.md @@ -1,48 +1,139 @@ -# Submit aggregate tunnel metrics to OONI +# Design Notes for OONI Metrics Integration -This pipeline component is meant to run periodically. The entry -point is `submitter.py` and the other scripts are lower-level -components of the overall implementation: +* Author: sbs +* Status: draft/notes +* Xref: https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +* Branch: https://0xacab.org/solitech/monitoring/-/commits/feat/aggregate_tunnel_metrics -- `aggregator.py` contains the code to aggregate endpoint metrics; +THIS IS A WORKING DOCUMENT for the integration branch. The branch +itself is meant for development and for showing how everything would +fit together. I'd like to actually merge the changes as smaller, +focused, well-tested merge requests. -- `model.py` contains the data model; +Note: while this document asks questions about future +trajectory and improvements, my aim here is mainly +to ensure that my assumptions and understanding with +respect to the current and future trajectory are correct +and that I am not making wrong assumptions due to lack +of context or participation to previous convos. -- `serializer.py` creates aggregate OONI measurements. -The component reads on-disk state specifying the last measurement -that was submitted and the path to the CSV file or files. +## Current Pipeline Understanding -It loads the CSV entries and goes through them building aggregated -endpoint results, and produces OONI measurements. +The current design is simple and compact. I tried to understand +it first before proposing how to integrate the OONI metrics +submission component. I would like you to double check my +understanding and assumptions; feedback on this would help +me to contribute better in the future. -Finally, the measurements are submitted to the OONI collector. +IIUC, the existing ETL pipeline has this structure: + +```ascii +[field tests] -> [logs/*.txt] -> [parse_all_logs.py (batch)] -> [metrics_all_logs.csv] -> [dashboard.py (dedup)] + ^ + | + [file-monitor (fastpath)] -> [prometheus metrics] +``` + +Key characteristics: + +1. Using columnar store (CSV). + +2. `parse_all_logs.py` is the batch/reprocessing component. + +3. `file-monitor` is the fastpath/streaming component. + +4. Two current outputs: +- Dashboard visualization (via pandas); +- Grafana dashboard (via pushgateway and prometheus metrics). + +5. Deduplication happens late (`dashboard.py`). + +6. `docker-compose` used to put Python scripts together. + +7. No CSV file locking, unless I am missing some detail here. + +8. Given the focus on reprocessing I assume the intent is to keep +the logfiles forever and preriodically reprocess them. + +9. I assume, given the current project structure, that scaling will +be implemented by partitioning the CSV file across dates, or that we +will eventually switch to a different columnar data store (e.g., +Parquet) and partition it -- my overall observation here is that a +columnar format seems a good fit for this kind of data. + + +## Topics to Discuss + +1. I am adding 3-4 new files to the pipeline and I am following +the existing coding style of not using `__init__.py` but I think +we should discuss whether the amount of code I am adding could +suggest we start adding some `__init__.py` and modules. + +2. IIUC, there is no file locking protecting the columnar data +store, so it is possible that multiple readers and writers could +race to update the datastore itself. I wonder whether it would +be a good idea to consider adding file locking. + +3. I tried to design the OONI submitter as an independent +component that could work also in case of time based partitioning +of the CSV files, under the assumption that this would be the +future expected growth trajectory. + +4. To make OONI export and submission more robust, I suppose we +should see to move deduplication in the batch processing, rather +than inside the dashboard support script -- sounds good? + +5. I think we need to introduce OONI-export-specific state that +remembers the last processed log entry, to avoid considering them +multiple times. This state should be orthogonal to the CSV and +persist also in case of pipeline reprocessing. + +6. I needed to *extend* the CSV data format to include more +fields required to produce the OONI measurements. I have not +implemented changes in the code generating the CSVs yet, because +I think we should ensure we all agree to include these extra +fields (I am aiming to submit at `endpoint` scope, thus I need +to include the bridge address and port -- please, let me know +if my understanding of what I should do here is wrong!). See +below for details about the CSV field extension. + + +### Short-term Integration Strategy + +I assume we could proceed as follows: -This component is designed to interface with the current ETL pipline, -which uses a single flat columnar CSV-based datastore. +1. Initial Integration + - Manual OONI submission from deduplicated CSV + - Run after parse_all_logs.py completes + - Use state file to track last submission -We probably want to start using file locking since we have more -then one component reading or writing the CSV file(s) at the same time. +2. Minimal Pipeline Changes + - Add required fields to CSV + - Keep existing deduplication for now + - Document operational procedures -TODO(bassosimone): deduplication happens using pandas and we use the -CSVs, so we should probably deduplicate ourselves here. +As a later step, we could consider how processing of telemetry +should work, but this seems a few months away from now. -## 2024-06-02: Additional issues -While working to write unit and integration tests (and specifically while -writing fixtures), I realised that the current CSV format is not sufficient -to produce aggregate tunnel metrics. My proposal is to fix this by adding -additional fields to the CSV data format. (I don't think we should shuffle -the fields to retain backwards compatibility.) +### Future Considerations -We may also want to rename some fields for clarity, however this is a -secondary order concern, for in the CSV domain what actually matters is -the index of the fields. Additionally, I don't know if the name of the -fielfs is used elsewhere in the pipeline (need to double check). +1. Deduplication + - Move to parse_all_logs.py + - Implement at CSV write time -### CSV Field Structure +2. Scaling + - Date-based CSV partitioning + - Add file locking -#### Current Fields +3. Automation + - Periodic OONI submission (via cron?) + + +## CSV Field Structure + +### Current Fields | Field | Purpose | Source | |---------------------|-----------------------------|--------------------| @@ -63,7 +154,7 @@ fielfs is used elsewhere in the pipeline (need to double check). | err_message | Error details | Parser detection | | PT | Protocol type | obfsvpn logs | -#### Required Additional Fields +### Required Additional Fields | Field | Purpose | Source | |----------------------|----------------------------|---------------------| @@ -74,8 +165,50 @@ fielfs is used elsewhere in the pipeline (need to double check). | endpoint_cc | Tunnel endpoint CC | (needs resolution) | | ping_target_address | Ping test target IP | ping command target | -#### CSV Field Mapping Notes + +### CSV Field Mapping Notes + +Notes on the `@dataclass` mapping the CSV (in `./oonisubmitter/model.py`): - The `protocol` field is mapped from the CSV column named `PT` - Empty error messages are represented as "" (empty string) - The `filename` field is parsed but not used for metrics aggregation + + +## Implementation Structure + +This pipeline component is meant to run periodically. The entry +point is `submitter.py` and the other scripts are lower-level +components of the overall implementation: + +- `aggregator.py` contains the code to aggregate endpoint metrics; + +- `model.py` contains the I/O data model; + +- `serializer.py` creates aggregate OONI measurements. + +The component reads on-disk state specifying the last measurement +that was submitted and the path to the CSV file or files. + +It loads the CSV entries and goes through them building aggregated +endpoint results, and produces OONI measurements. + +Finally, the measurements are submitted to the OONI collector. + +The high-level API of `submitter.py` is still a bit in flux +because I need to understand how we want to integrate this with +the rest of the ETL pipeline first. + + +## Recap: Questions for Reviewers + +*Immediate needs*: + +1. Is my understanding of the pipeline correct? +2. Is manual submission acceptable for initial implementation? +3. Are the proposed CSV-fields extensions acceptable? + +*Future considerations*: + +4. Should we consider implementing file locking? +5. What do you think of moving the deduplication to the batch processing? -- GitLab From 8c5606db5b0969b2a734417d7ca059a9f483b808 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 9 Dec 2024 13:08:54 +0100 Subject: [PATCH 06/75] feat(oonisubmitter): implement and use file locking This diff implements file locking using as the blueprint an existing well-known pattern in the Go landscape. We also modify oonisubmitter to use this new feature. We probably need to use a similar pattern also for other parts of the ETL pipeline. Tests are left as an exercise for my future self. --- oonisubmitter/lockedfile.py | 115 ++++++++++++++++++++++++++++++++++++ oonisubmitter/submitter.py | 25 +++++--- 2 files changed, 132 insertions(+), 8 deletions(-) create mode 100644 oonisubmitter/lockedfile.py diff --git a/oonisubmitter/lockedfile.py b/oonisubmitter/lockedfile.py new file mode 100644 index 0000000..1d0dca8 --- /dev/null +++ b/oonisubmitter/lockedfile.py @@ -0,0 +1,115 @@ +""" +Support module for reading and writing files while holding a lock on them. + +Only work on Unix systems. + +Patterned after https://github.com/rogpeppe/go-internal `lockedfile`. + +SPDX-License-Identifier: BSD-3-Clause +""" + +from dataclasses import dataclass +from io import TextIOWrapper +from typing import Optional + +import fcntl +import os +import time + +# TODO(bassosimone): write tests for this functionality +# once we have addresses more pressing issues with wiring +# in the `oonisubmitter` module into the ETL pipeline. + + +class FileLockError(Exception): + """Error emitted by this module.""" + + +@dataclass +class Config: + """ + Configures attempting to acquire a lock. + + Fields: + num_retries: Number of times to retry acquiring the lock + sleep_interval: Time to sleep between each attempt (in seconds) + """ + + num_retries: int = 10 + sleep_interval: float = 0.1 + + +def read(filepath: str, config: Optional[Config] = None) -> str: + """ + Read entire file while holding a shared lock. + + Args: + filepath: Path to file to read + config: Optional locking configuration + + Raises: + FileLockError: if cannot acquire the file lock + IOError: if file operations fail + FileNotFoundError: if the file does not exist + """ + with open(filepath, "r") as filep: + if not _acquire_shared(filep, config): + raise FileLockError(f"cannot acquire read lock on {filepath}") + try: + return filep.read() + finally: + _release(filep) + + +def write(filepath: str, data: str, config: Optional[Config] = None) -> None: + """ + Write entire file while holding an exclusive lock. + + Args: + filepath: Path to file to write + data: Content to write + config: Optional locking configuration + + Raises: + FileLockError: if cannot acquire the file lock + IOError: if file operations fail + FileNotFoundError: if the file does not exist + """ + with open(filepath, "w") as filep: + if not _acquire_exclusive(filep, config): + raise FileLockError(f"cannot acquire write lock on {filepath}") + try: + filep.write(data) + # Implementation note: flush to buffer cache and then + # persist to permanent storage with fsync + filep.flush() + os.fsync(filep.fileno()) + finally: + _release(filep) + + +def _acquire_shared(filep: TextIOWrapper, config: Optional[Config]) -> bool: + return _try_lock(filep, fcntl.LOCK_SH | fcntl.LOCK_NB, config) + + +def _acquire_exclusive(filep: TextIOWrapper, config: Optional[Config]) -> bool: + return _try_lock(filep, fcntl.LOCK_EX | fcntl.LOCK_NB, config) + + +def _release(filep: TextIOWrapper) -> None: + try: + fcntl.flock(filep.fileno(), fcntl.LOCK_UN) + except OSError: + pass + + +def _try_lock(filep: TextIOWrapper, operation: int, config: Optional[Config]) -> bool: + if not config: + config = Config() + for _ in range(config.num_retries): + try: + fcntl.flock(filep.fileno(), operation) + return True + except BlockingIOError: + time.sleep(config.sleep_interval) + return False diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 51338a3..d166fe4 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -13,6 +13,10 @@ testing measurements to aggregate; 4. the window of time that we're aggregating. +Uses: + +1. lockedfile.read and lockedfile.write to read and write the state file. + See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md """ @@ -27,6 +31,8 @@ from aggregator import AggregatorConfig, EndpointAggregator from model import FieldTestingCSVEntry, FieldTestingCSVFile, OONIMeasurement from serializer import OONISerializer, SerializationConfigError +import lockedfile + class OONISubmitter: """ @@ -34,8 +40,6 @@ class OONISubmitter: and serialization. """ - # TODO(bassosimone): implement state-file locking - def __init__(self, config: AggregatorConfig, state_file: str): """ Initializes the submitter with the given configuration and state file. @@ -47,18 +51,23 @@ class OONISubmitter: def _load_state(self) -> Dict: """ - Load submission state from disk. + Load submission state from disk using lockedfile.read. """ try: - with open(self.state_file) as f: - return json.load(f) + return json.loads(lockedfile.read(self.state_file)) except FileNotFoundError: return {"last_submitted": None} + except IOError: + return {"last_submitted": None} + except json.JSONDecodeError: + logging.warning("invalid state file, resetting") + return {"last_submitted": None} def _save_state(self): - """Save current state to disk""" - with open(self.state_file, "w") as f: - json.dump(self.state, f) + """ + Save current state to disk using lockedfile.write. + """ + lockedfile.write(self.state_file, json.dumps(self.state)) def _should_process(self, entry: FieldTestingCSVEntry) -> bool: """ -- GitLab From 09ce59d4719ea555c9f48d4c376fa894635dabf4 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 9 Dec 2024 15:04:00 +0100 Subject: [PATCH 07/75] feat(oonisubmitter): round and possibly omit sample_size This follows the spec more strictly. --- oonisubmitter/aggregator.py | 8 +++++ oonisubmitter/serializer.py | 72 +++++++++++++++++++------------------ 2 files changed, 45 insertions(+), 35 deletions(-) diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index baaee6c..a0264cd 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -20,12 +20,20 @@ class AggregatorConfig: Configuration for the measurement aggregator. """ + # TODO(bassosimone): consider renaming this class + provider: str upstream_collector: str probe_asn: str probe_cc: str scope: str = "endpoint" # for now we only care about endpoint scope + # threshold below which we emit sample_size + min_sample_size: int = 1000 + + # rounding sample_size to the nearest round_to + round_to: int = 100 + @dataclass class AggregateEndpointState: diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index 10ec694..53799b7 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -5,7 +5,7 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag """ from datetime import datetime -from typing import Dict, List +from typing import Any, Dict, List, Optional from statistics import quantiles # TODO(bassosimone): I think we should consider creating @@ -28,11 +28,9 @@ class OONISerializer: def _compute_percentiles(values: List[float]) -> Dict[str, float]: """Compute the required percentiles for OONI format""" - # TODO(bassosimone): we should not emit data if we have - # less than a configurable amount of measurements! - if not values: return {} + p25, p50, p75, p99 = quantiles(values, n=100, method="exclusive") return { "25p": round(p25, 1), @@ -58,11 +56,19 @@ class OONISerializer: return base + "?" + "&".join(f"{k}={v}" for k, v in params.items() if v) return base - def _round_sample_size(self, sample_size: int) -> int: - """Round the sample size according to the aggregate tunnel metrics spec""" + def _round_sample_size(self, sample_size: int) -> Optional[int]: + """Round the sample size according to the aggregate tunnel metrics spec.""" # TODO(bassosimone): implement rounding of the sample size # according to what has been written inside the spec - return sample_size + if sample_size < self.config.min_sample_size: + return None + return round(sample_size / self.config.round_to) * self.config.round_to + + @staticmethod + def _maybe_with_sample_size(obj: Dict[str, Any], ss: Optional[int]) -> Dict[str, Any]: + if ss is not None: + obj["sample_size"] = ss + return obj def _create_error_bodies(self, state: AggregateEndpointState) -> List[Dict]: """Create error bodies if there are any errors""" @@ -72,15 +78,15 @@ class OONISerializer: for error_type, count in state.errors.items(): if not error_type: # Skip success counts continue - bodies.append( + bodies.append(self._maybe_with_sample_size( { "phase": "creation", - "sample_size": self._round_sample_size(total), "type": "network-error", "failure_ratio": round(count / total, 2), "error": error_type, - } - ) + }, + self._round_sample_size(count), + )) return bodies def _create_ping_bodies(self, state: AggregateEndpointState) -> List[Dict]: @@ -90,33 +96,29 @@ class OONISerializer: # Min/Avg/Max latency bodies for metric_type in ["min", "avg", "max"]: if measurements[metric_type]: # Only if we have measurements - bodies.append( + bodies.append(self._maybe_with_sample_size( { "phase": "tunnel_ping", - "sample_size": self._round_sample_size( - len(measurements[metric_type]) - ), "type": f"ping_{metric_type}", "target_address": target_address, "latency_ms": self._compute_percentiles( measurements[metric_type] ), - } - ) + }, + self._round_sample_size(len(measurements[metric_type])) + )) # Packet loss body if measurements["loss"]: - bodies.append( + bodies.append(self._maybe_with_sample_size( { "phase": "tunnel_ping", - "sample_size": self._round_sample_size( - len(measurements["loss"]) - ), "type": "ping_loss", "target_address": target_address, "loss_percent": self._compute_percentiles(measurements["loss"]), - } - ) + }, + self._round_sample_size(len(measurements["loss"])), + )) return bodies def _create_ndt_bodies(self, state: AggregateEndpointState) -> List[Dict]: @@ -130,12 +132,9 @@ class OONISerializer: # Download measurements if measurements["download_throughput"]: - bodies.append( + bodies.append(self._maybe_with_sample_size( { "phase": "tunnel_ndt_download", - "sample_size": self._round_sample_size( - len(measurements["download_throughput"]) - ), "type": "ndt_download", "target_hostname": hostname, "target_address": address, @@ -149,17 +148,17 @@ class OONISerializer: "retransmission_percent": self._compute_percentiles( measurements["download_retransmission"] ), - } - ) + }, + self._round_sample_size( + len(measurements["download_throughput"]) + ), + )) # Upload measurements if measurements["upload_throughput"]: - bodies.append( + bodies.append(self._maybe_with_sample_size( { "phase": "tunnel_ndt_upload", - "sample_size": self._round_sample_size( - len(measurements["upload_throughput"]) - ), "type": "ndt_upload", "target_hostname": hostname, "target_address": address, @@ -173,8 +172,11 @@ class OONISerializer: "retransmission_percent": self._compute_percentiles( measurements["upload_retransmission"] ), - } - ) + }, + self._round_sample_size( + len(measurements["upload_throughput"]) + ), + )) return bodies def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: -- GitLab From 5e1fec7bc7eaa5c57f46ca8c25a7339d2122f580 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 9 Dec 2024 15:19:56 +0100 Subject: [PATCH 08/75] fix(oonisubmitter): safely generate URLs Make sure we correctly generate all the fields. --- oonisubmitter/serializer.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index 53799b7..e52038b 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -7,6 +7,7 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag from datetime import datetime from typing import Any, Dict, List, Optional from statistics import quantiles +from urllib.parse import urlunparse, urlencode # TODO(bassosimone): I think we should consider creating # proper python modules eventually :thinking: @@ -41,20 +42,28 @@ class OONISerializer: def _create_input_url(self, state: AggregateEndpointState) -> str: """Create the measurement input URL""" - base = f"{state.protocol}://{state.provider}/" - # TODO(bassosimone): use an enum here? - # TODO(bassosimone): we currently only use endpoint scope - # so wondering whether to keep this if here + # TODO(bassosimone): use enum for the scope. + # Optionally include query + query = {} if state.scope == "endpoint": - params = { + query = { "address": state.address, "asn": state.asn, "hostname": state.hostname, "port": str(state.port), } - # TODO(bassosimone): serialise with proper urlencoding - return base + "?" + "&".join(f"{k}={v}" for k, v in params.items() if v) - return base + # Filter out None/empty values + query = {k: v for k, v in query.items() if v} + + # Build URL using urlunparse for safety + return urlunparse(( + state.protocol, # scheme (e.g., "openvpn+obfs4") + state.provider, # netloc (e.g., "riseup.net") + "/", # path + "", # params + urlencode(query), # query (e.g., "address=1.2.3.4&...") + "" # fragment + )) def _round_sample_size(self, sample_size: int) -> Optional[int]: """Round the sample size according to the aggregate tunnel metrics spec.""" -- GitLab From c9933483ee0450eb0caac5cd3bc29c689975fd24 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 9 Dec 2024 15:21:55 +0100 Subject: [PATCH 09/75] cleanup(oonisubmitter): zap addressed TODO --- oonisubmitter/serializer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index e52038b..d614e0f 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -67,8 +67,6 @@ class OONISerializer: def _round_sample_size(self, sample_size: int) -> Optional[int]: """Round the sample size according to the aggregate tunnel metrics spec.""" - # TODO(bassosimone): implement rounding of the sample size - # according to what has been written inside the spec if sample_size < self.config.min_sample_size: return None return round(sample_size / self.config.round_to) * self.config.round_to -- GitLab From 364e14905cf85ff6789b74cd04a7a0a1c3339e2e Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 9 Dec 2024 15:32:44 +0100 Subject: [PATCH 10/75] fix(oonisubmitter): use proper enum for scope --- oonisubmitter/aggregator.py | 4 ++-- oonisubmitter/model.py | 11 ++++++++++- oonisubmitter/serializer.py | 13 ++++++------- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index a0264cd..52f3c4b 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -26,7 +26,7 @@ class AggregatorConfig: upstream_collector: str probe_asn: str probe_cc: str - scope: str = "endpoint" # for now we only care about endpoint scope + scope: model.Scope = model.Scope.ENDPOINT # for now we only care about endpoint scope # threshold below which we emit sample_size min_sample_size: int = 1000 @@ -57,7 +57,7 @@ class AggregateEndpointState: provider: str probe_asn: str probe_cc: str - scope: str + scope: model.Scope # Time window window_start: datetime diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index a1cb6a1..fce15da 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -28,11 +28,20 @@ See: from dataclasses import dataclass from datetime import datetime, timezone +from enum import Enum from typing import Dict, List, Optional + import csv import logging +class Scope(Enum): + """Valid scopes for aggregate tunnel metrics.""" + ENDPOINT = "endpoint" + ENDPOINT_POOL = "endpoint_pool" + GLOBAL = "global" + + @dataclass class FieldTestingCSVEntry: """ @@ -237,7 +246,7 @@ class AggregateTunnelMetricsTestKeys: """ provider: str - scope: str # "endpoint", "endpoint_pool", or "global" + scope: Scope protocol: str time_window: AggregationTimeWindow diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index d614e0f..b3c4eba 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -42,10 +42,9 @@ class OONISerializer: def _create_input_url(self, state: AggregateEndpointState) -> str: """Create the measurement input URL""" - # TODO(bassosimone): use enum for the scope. # Optionally include query query = {} - if state.scope == "endpoint": + if state.scope == model.Scope.ENDPOINT: query = { "address": state.address, "asn": state.asn, @@ -199,9 +198,9 @@ class OONISerializer: Convert endpoint state to OONI measurement format. Raises: - SerializationError: if the scope is not "endpoint" + SerializationError: if the scope is not model.Scope.ENDPOINT. """ - if state.scope != "endpoint": + if state.scope != model.Scope.ENDPOINT: raise SerializationConfigError( f"cannot serialize measurement with scope '{state.scope}': " "only 'endpoint' scope is currently supported" @@ -216,9 +215,9 @@ class OONISerializer: time_window=model.AggregationTimeWindow( from_time=state.window_start, to_time=state.window_end ), - endpoint_hostname=state.hostname if state.scope == "endpoint" else None, - endpoint_address=state.address if state.scope == "endpoint" else None, - endpoint_port=state.port if state.scope == "endpoint" else None, + endpoint_hostname=state.hostname if state.scope == model.Scope.ENDPOINT else None, + endpoint_address=state.address if state.scope == model.Scope.ENDPOINT else None, + endpoint_port=state.port if state.scope == model.Scope.ENDPOINT else None, asn=state.asn, cc=state.cc, bodies=self._create_bodies(state), -- GitLab From 21f667d76b5c0f615ac3af097d6bd3db70b37486 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 9 Dec 2024 15:52:27 +0100 Subject: [PATCH 11/75] refactor(oonisubmitter): use explicit endpoint and NDT server ID Slightly improves code correctness and maintainability. --- oonisubmitter/aggregator.py | 18 +++- oonisubmitter/identifiers.py | 45 +++++++++ oonisubmitter/model.py | 1 + oonisubmitter/serializer.py | 184 +++++++++++++++++++---------------- 4 files changed, 162 insertions(+), 86 deletions(-) create mode 100644 oonisubmitter/identifiers.py diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index 52f3c4b..457bef9 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -11,6 +11,7 @@ from typing import Dict, List # TODO(bassosimone): I think we should consider creating # proper python modules eventually :thinking: +from identifiers import EndpointID, NDTServerID import model @@ -26,7 +27,9 @@ class AggregatorConfig: upstream_collector: str probe_asn: str probe_cc: str - scope: model.Scope = model.Scope.ENDPOINT # for now we only care about endpoint scope + scope: model.Scope = ( + model.Scope.ENDPOINT + ) # for now we only care about endpoint scope # threshold below which we emit sample_size min_sample_size: int = 1000 @@ -126,7 +129,9 @@ class AggregateEndpointState: def _update_ndt(self, entry: model.FieldTestingCSVEntry) -> None: """Unconditionally update the NDT metrics.""" - ndt_target = f"{entry.server_fqdn}:{entry.server_ip}" + ndt_target = str( + NDTServerID(hostname=entry.server_fqdn, address=entry.server_ip) + ) if ndt_target not in self.ndt_measurements: self.ndt_measurements[ndt_target] = { "download_throughput": [], @@ -208,7 +213,14 @@ class EndpointAggregator: def _make_key(self, entry: model.FieldTestingCSVEntry) -> str: """Create unique key for an endpoint""" - return f"{entry.endpoint_hostname}|{entry.endpoint_address}|{entry.endpoint_port}|{entry.protocol}" + return str( + EndpointID( + hostname=entry.endpoint_hostname, + address=entry.endpoint_address, + port=entry.endpoint_port, + protocol=entry.protocol, + ) + ) def _is_in_window(self, entry: model.FieldTestingCSVEntry) -> bool: """Check if entry falls within our time window""" diff --git a/oonisubmitter/identifiers.py b/oonisubmitter/identifiers.py new file mode 100644 index 0000000..5149c22 --- /dev/null +++ b/oonisubmitter/identifiers.py @@ -0,0 +1,45 @@ +"""Common identifiers used across the pipeline.""" + +from dataclasses import dataclass + +SEPARATOR = "|" + + +@dataclass +class NDTServerID: + """NDT server identifier.""" + + hostname: str + address: str + + def __str__(self) -> str: + """String representation for dictionary keys.""" + return f"{self.hostname}{SEPARATOR}{self.address}" + + @classmethod + def parse(cls, identifier: str) -> "NDTServerID": + """Parse identifier string into a NDTServerID.""" + hostname, address = identifier.split(SEPARATOR, 1) + return cls(hostname=hostname, address=address) + + +@dataclass +class EndpointID: + """Endpoint identifier.""" + + hostname: str + address: str + port: int + protocol: str + + def __str__(self) -> str: + """String representation for dictionary keys.""" + return f"{self.hostname}{SEPARATOR}{self.address}{SEPARATOR}{self.port}{SEPARATOR}{self.protocol}" + + @classmethod + def parse(cls, identifier: str) -> "EndpointID": + """Parse identifier string into an Endpoint.""" + hostname, address, port, protocol = identifier.split(SEPARATOR, 3) + return cls( + hostname=hostname, address=address, port=int(port), protocol=protocol + ) diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index fce15da..8ddbfc0 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -37,6 +37,7 @@ import logging class Scope(Enum): """Valid scopes for aggregate tunnel metrics.""" + ENDPOINT = "endpoint" ENDPOINT_POOL = "endpoint_pool" GLOBAL = "global" diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index b3c4eba..b4ebe26 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -11,8 +11,9 @@ from urllib.parse import urlunparse, urlencode # TODO(bassosimone): I think we should consider creating # proper python modules eventually :thinking: -import model +from identifiers import NDTServerID from aggregator import AggregateEndpointState, AggregatorConfig +import model class SerializationConfigError(Exception): @@ -55,14 +56,16 @@ class OONISerializer: query = {k: v for k, v in query.items() if v} # Build URL using urlunparse for safety - return urlunparse(( - state.protocol, # scheme (e.g., "openvpn+obfs4") - state.provider, # netloc (e.g., "riseup.net") - "/", # path - "", # params - urlencode(query), # query (e.g., "address=1.2.3.4&...") - "" # fragment - )) + return urlunparse( + ( + state.protocol, # scheme (e.g., "openvpn+obfs4") + state.provider, # netloc (e.g., "riseup.net") + "/", # path + "", # params + urlencode(query), # query (e.g., "address=1.2.3.4&...") + "", # fragment + ) + ) def _round_sample_size(self, sample_size: int) -> Optional[int]: """Round the sample size according to the aggregate tunnel metrics spec.""" @@ -71,7 +74,9 @@ class OONISerializer: return round(sample_size / self.config.round_to) * self.config.round_to @staticmethod - def _maybe_with_sample_size(obj: Dict[str, Any], ss: Optional[int]) -> Dict[str, Any]: + def _maybe_with_sample_size( + obj: Dict[str, Any], ss: Optional[int] + ) -> Dict[str, Any]: if ss is not None: obj["sample_size"] = ss return obj @@ -84,15 +89,17 @@ class OONISerializer: for error_type, count in state.errors.items(): if not error_type: # Skip success counts continue - bodies.append(self._maybe_with_sample_size( - { - "phase": "creation", - "type": "network-error", - "failure_ratio": round(count / total, 2), - "error": error_type, - }, - self._round_sample_size(count), - )) + bodies.append( + self._maybe_with_sample_size( + { + "phase": "creation", + "type": "network-error", + "failure_ratio": round(count / total, 2), + "error": error_type, + }, + self._round_sample_size(count), + ) + ) return bodies def _create_ping_bodies(self, state: AggregateEndpointState) -> List[Dict]: @@ -102,87 +109,94 @@ class OONISerializer: # Min/Avg/Max latency bodies for metric_type in ["min", "avg", "max"]: if measurements[metric_type]: # Only if we have measurements - bodies.append(self._maybe_with_sample_size( + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ping", + "type": f"ping_{metric_type}", + "target_address": target_address, + "latency_ms": self._compute_percentiles( + measurements[metric_type] + ), + }, + self._round_sample_size(len(measurements[metric_type])), + ) + ) + + # Packet loss body + if measurements["loss"]: + bodies.append( + self._maybe_with_sample_size( { "phase": "tunnel_ping", - "type": f"ping_{metric_type}", + "type": "ping_loss", "target_address": target_address, - "latency_ms": self._compute_percentiles( - measurements[metric_type] + "loss_percent": self._compute_percentiles( + measurements["loss"] ), }, - self._round_sample_size(len(measurements[metric_type])) - )) + self._round_sample_size(len(measurements["loss"])), + ) + ) - # Packet loss body - if measurements["loss"]: - bodies.append(self._maybe_with_sample_size( - { - "phase": "tunnel_ping", - "type": "ping_loss", - "target_address": target_address, - "loss_percent": self._compute_percentiles(measurements["loss"]), - }, - self._round_sample_size(len(measurements["loss"])), - )) return bodies def _create_ndt_bodies(self, state: AggregateEndpointState) -> List[Dict]: """Create bodies for NDT measurements""" bodies = [] for target_id, measurements in state.ndt_measurements.items(): - # TODO(bassosimone): I am not convinced we should use this algorithm - # here to represent an NDT server and maybe we could go for sth - # that is more direct and the `:` is annoying because of IPv6 anyway - hostname, address = target_id.split(":", 1) + server = NDTServerID.parse(target_id) # Download measurements if measurements["download_throughput"]: - bodies.append(self._maybe_with_sample_size( - { - "phase": "tunnel_ndt_download", - "type": "ndt_download", - "target_hostname": hostname, - "target_address": address, - "target_port": 443, # TODO: Get actual port - "latency_ms": self._compute_percentiles( - measurements["download_latency"] - ), - "speed_mbits": self._compute_percentiles( - measurements["download_throughput"] - ), - "retransmission_percent": self._compute_percentiles( - measurements["download_retransmission"] + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ndt_download", + "type": "ndt_download", + "target_hostname": server.hostname, + "target_address": server.address, + "target_port": 443, # TODO: Get actual port + "latency_ms": self._compute_percentiles( + measurements["download_latency"] + ), + "speed_mbits": self._compute_percentiles( + measurements["download_throughput"] + ), + "retransmission_percent": self._compute_percentiles( + measurements["download_retransmission"] + ), + }, + self._round_sample_size( + len(measurements["download_throughput"]) ), - }, - self._round_sample_size( - len(measurements["download_throughput"]) - ), - )) + ) + ) # Upload measurements if measurements["upload_throughput"]: - bodies.append(self._maybe_with_sample_size( - { - "phase": "tunnel_ndt_upload", - "type": "ndt_upload", - "target_hostname": hostname, - "target_address": address, - "target_port": 443, # TODO(bassosimone): get actual port - "latency_ms": self._compute_percentiles( - measurements["upload_latency"] - ), - "speed_mbits": self._compute_percentiles( - measurements["upload_throughput"] - ), - "retransmission_percent": self._compute_percentiles( - measurements["upload_retransmission"] - ), - }, - self._round_sample_size( - len(measurements["upload_throughput"]) - ), - )) + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ndt_upload", + "type": "ndt_upload", + "target_hostname": server.hostname, + "target_address": server.address, + "target_port": 443, # TODO(bassosimone): get actual port + "latency_ms": self._compute_percentiles( + measurements["upload_latency"] + ), + "speed_mbits": self._compute_percentiles( + measurements["upload_throughput"] + ), + "retransmission_percent": self._compute_percentiles( + measurements["upload_retransmission"] + ), + }, + self._round_sample_size(len(measurements["upload_throughput"])), + ) + ) + return bodies def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: @@ -215,8 +229,12 @@ class OONISerializer: time_window=model.AggregationTimeWindow( from_time=state.window_start, to_time=state.window_end ), - endpoint_hostname=state.hostname if state.scope == model.Scope.ENDPOINT else None, - endpoint_address=state.address if state.scope == model.Scope.ENDPOINT else None, + endpoint_hostname=( + state.hostname if state.scope == model.Scope.ENDPOINT else None + ), + endpoint_address=( + state.address if state.scope == model.Scope.ENDPOINT else None + ), endpoint_port=state.port if state.scope == model.Scope.ENDPOINT else None, asn=state.asn, cc=state.cc, -- GitLab From 32453fdc2b13e46878386ca580842c91fbdf6859 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Thu, 12 Dec 2024 14:03:11 +0100 Subject: [PATCH 12/75] feat(oonisubmitter): implement time window management --- oonisubmitter/policy.py | 77 +++++++++++++ oonisubmitter/submitter.py | 216 ++++++++++++++++++++++--------------- 2 files changed, 204 insertions(+), 89 deletions(-) create mode 100644 oonisubmitter/policy.py diff --git a/oonisubmitter/policy.py b/oonisubmitter/policy.py new file mode 100644 index 0000000..6e20bb2 --- /dev/null +++ b/oonisubmitter/policy.py @@ -0,0 +1,77 @@ +""" +Manages different policies for submitting OONI measurements. +""" + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta, timezone +from typing import List, Tuple + +# Define project start as a module constant +PROJECT_START = datetime(2024, 1, 1, tzinfo=timezone.utc) + + +class WindowPolicy(ABC): + """ + The abstract policy for dividing measurements into timed windows. + """ + + def generate_windows( + self, last_submission: datetime, current_time: datetime + ) -> List[Tuple[datetime, datetime]]: + """ + Generate all windows between last_submission and current_time + that do not include the current time. + + Uses half-open intervals, i.e., [start, end). + """ + windows = [] + window_start = last_submission + + while window_start < current_time: + window_end = self._compute_window_end(window_start) + + # If this window contains current_time, it's still open + if current_time <= window_end: + break + + windows.append((window_start, window_end)) + window_start = window_end + + return windows + + @abstractmethod + def _compute_window_end(self, start: datetime) -> datetime: + """Compute the window end for a window starting at start""" + pass + + +class WeeklyWindowPolicy(WindowPolicy): + """A WindowPolicy organizing measurements in weekly buckets.""" + + def _compute_window_end(self, start: datetime) -> datetime: + days_until_sunday = (7 - start.isoweekday()) % 7 + return (start + timedelta(days=days_until_sunday)).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + + +class DailyWindowPolicy(WindowPolicy): + """A WindowPolicy organizing measurements in daily buckets.""" + + def _compute_window_end(self, start: datetime) -> datetime: + return start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( + days=1 + ) + +class HourlyWindowPolicy(WindowPolicy): + """ + A WindowPolicy organizing measurements in hourly buckets. + + Primarily intended for testing, but could also be useful for + high-frequency measurement scenarios. + """ + + def _compute_window_end(self, start: datetime) -> datetime: + return start.replace( + minute=0, second=0, microsecond=0 + ) + timedelta(hours=1) diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index d166fe4..670ccc1 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -1,139 +1,177 @@ """ -Entry point for the aggregate-endpoint-metrics OONI measurement submitter. +Entry point for the aggregate-tunnel-metrics OONI measurement submitter. -Depends on: +This module provides the main interface for submitting aggregate +tunnel metrics to OONI. It handles: -1. AggregationConfig specifying what we are aggregating; +1. Measurement aggregation and submission -2. file path of the state file, tracking the last field testing -measurement we considered for aggregation; +2. State management -3. FieldTestingCSVFile (or files) representing the field -testing measurements to aggregate; +3. Time window processing -4. the window of time that we're aggregating. +Example usage: -Uses: - -1. lockedfile.read and lockedfile.write to read and write the state file. + submitter = OONISubmitter( + config=AggregatorConfig(...), + state_file="/var/lib/oonisubmitter/state.json", + window_policy=WeeklyWindowPolicy() + ) + submitter.process_csv("/path/to/metrics.csv") See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +for more information on the aggregate-tunnel-metrics OONI-compatible data format. """ -from datetime import datetime -from typing import Dict, List +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Optional import json -import logging # TODO(bassosimone): I think we should consider creating # proper python modules eventually :thinking: from aggregator import AggregatorConfig, EndpointAggregator -from model import FieldTestingCSVEntry, FieldTestingCSVFile, OONIMeasurement -from serializer import OONISerializer, SerializationConfigError +from model import FieldTestingCSVFile, OONIMeasurement +from policy import WindowPolicy, PROJECT_START +from serializer import OONISerializer import lockedfile +@dataclass +class SubmitterState: + """Persistent state of the submitter.""" + + next_submission_after: Optional[datetime] = None + + # TODO(bassosimone): consider using a more portable data format + # for storing the datetime, e.g., `20060102T150405Z`. + + @classmethod + def from_dict(cls, data: dict) -> "SubmitterState": + """Create from dictionary.""" + return cls( + next_submission_after=( + datetime.fromisoformat(data["next_submission_after"]) + if data.get("next_submission_after") + else None + ) + ) + + def to_dict(self) -> dict: + """Convert to dictionary.""" + return { + "next_submission_after": ( + self.next_submission_after.isoformat() + if self.next_submission_after + else None + ) + } + + class OONISubmitter: """ - Manages the submission of OONI measurements, including state tracking - and serialization. + Manages the submission of OONI measurements, including state + tracking and serialization, and window management. + + Args: + config: Configuration for aggregation + state_file: Path to persistent state file + window_policy: The window policy """ - def __init__(self, config: AggregatorConfig, state_file: str): - """ - Initializes the submitter with the given configuration and state file. - """ + def __init__( + self, + config: AggregatorConfig, + state_file: str, + window_policy: WindowPolicy, + ): self.config = config self.state_file = state_file - self.state = self._load_state() self.serializer = OONISerializer(config) + self.window_policy = window_policy + self.state = self._load_state() - def _load_state(self) -> Dict: + def _load_state(self) -> SubmitterState: """ - Load submission state from disk using lockedfile.read. + Load SubmitterState from disk using lockedfile.read. """ try: - return json.loads(lockedfile.read(self.state_file)) - except FileNotFoundError: - return {"last_submitted": None} - except IOError: - return {"last_submitted": None} - except json.JSONDecodeError: - logging.warning("invalid state file, resetting") - return {"last_submitted": None} + data = json.loads(lockedfile.read(self.state_file)) + return SubmitterState.from_dict(data) + except (FileNotFoundError, json.JSONDecodeError, IOError): + return SubmitterState() def _save_state(self): """ Save current state to disk using lockedfile.write. """ - lockedfile.write(self.state_file, json.dumps(self.state)) + lockedfile.write(self.state_file, json.dumps(self.state.to_dict())) + + def _submit(self, measurement: OONIMeasurement) -> None: + """ + Submit measurement to OONI. + """ + # TODO(bassosimone): implement the actual submission logic. + # For now we just print what we would do for debugging. + print(f"Would submit: {measurement.as_dict()}") - def _should_process(self, entry: FieldTestingCSVEntry) -> bool: + def process_csv(self, csv_path: str) -> None: """ - Determine whether we should process this entry based on - last submission state. + Process CSV file, submitting measurements for complete windows. Args: - entry: The CSV entry to check + csv_path: Path to CSV file to process - Returns: - True if entry should be processed, False otherwise + The submitter maintains state about the last processed window + and will continue from there on subsequent runs. """ - # If we've never submitted anything, just process everything - if self.state["last_submitted"] is None: - return True - - # Parse last submission time from state - last_submitted = datetime.fromisoformat(self.state["last_submitted"]) - - # Only process entries newer than our last submission - return entry.date > last_submitted - - # TODO(bassosimone): the current method is not super - # ergonomic for the following reasons: - # - # 1. we should probably load the CSV file(s) ourselves - # - # 2. we should auto-compute the time window - # - # For now, I am leaving it all as-is since we need - # this function for integration testing and we will - # need to iterate on the top-level API anyway: we - # still need to specify how to properly hook in the - # rest of the ETL pipeline. - - def process_csv_file( + csv_file = FieldTestingCSVFile(csv_path) + + # Start from last submission time or the default project start + start_time = self.state.next_submission_after + if not start_time: + start_time = PROJECT_START + + # Get the complete windows for which we need to submit + current_time = datetime.now(timezone.utc) + windows = self.window_policy.generate_windows(start_time, current_time) + + # Process each complete window in sequence + for window_start, window_end in windows: + self._process_with_window(csv_file, window_start, window_end) + # Update state after each window + self.state.next_submission_after = window_end + self._save_state() + + def _process_with_window( self, csv_file: FieldTestingCSVFile, - window_start: datetime, - window_end: datetime, - ) -> List[OONIMeasurement]: + start_time: datetime, + end_time: datetime, + ): """ - Process CSV file and return JSON strings for new measurements. - - Return the raw measurements to submit. + Process entries within a specific time window. - Note: DOES NOT update the state, which should be updated once the - measurements have been successfully submitted. + Args: + csv_file: The CSV file containing entries + start_time: Window start (inclusive) + end_time: Window end (exclusive) """ - aggregator = EndpointAggregator(self.config, window_start, window_end) + # Only consider entries in current window + window_entries = [ + entry for entry in csv_file.entries if start_time <= entry.date < end_time + ] + + # Only process if we have entries + if not window_entries: + return - # Only process the new field-testing entries - for entry in csv_file.entries: - if self._should_process(entry): - aggregator.update(entry) + aggregator = EndpointAggregator(self.config, start_time, end_time) + for entry in window_entries: + aggregator.update(entry) - # Generate the measurements - measurements: List[OONIMeasurement] = [] + # Create and submit measurements for endpoint in aggregator.endpoints.values(): - try: - measurement = self.serializer.serialize(endpoint) - measurements.append(measurement) - except SerializationConfigError: - logging.warning( - f"skipping endpoint {endpoint.hostname}: serialization config error" - ) - continue - - return measurements + measurement = self.serializer.serialize(endpoint) + self._submit(measurement) -- GitLab From ef5b1abd1ed1e7c84070ae32b30cadb8e453e582 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Thu, 12 Dec 2024 14:52:59 +0100 Subject: [PATCH 13/75] fix(submitter): hold file-based mutex while operating This ensures no-one could write the CSV file while we're reading it for the purpose of submitting. --- oonisubmitter/lockedfile.py | 34 ++++++++++++++++++++++++++++++++++ oonisubmitter/submitter.py | 8 ++++++++ 2 files changed, 42 insertions(+) diff --git a/oonisubmitter/lockedfile.py b/oonisubmitter/lockedfile.py index 1d0dca8..3f46823 100644 --- a/oonisubmitter/lockedfile.py +++ b/oonisubmitter/lockedfile.py @@ -39,6 +39,40 @@ class Config: sleep_interval: float = 0.1 +class Mutex: + """ + Provides mutual exclusion using a lock file and flock. + + The lock file persists between uses - only the lock itself + is acquired and released, not the file's existence. + """ + + def __init__(self, filepath: str): + self.filepath = filepath + self.filep = None + + def __enter__(self) -> 'Mutex': + """Acquire exclusive lock using flock""" + try: + # Open for read-write, create if doesn't exist + self.filep = open(self.filepath, 'a+') + fcntl.flock(self.filep.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + return self + except (IOError, BlockingIOError) as err: + if self.filep: + self.filep.close() + raise FileLockError(f"cannot acquire lock: {err}") + + def __exit__(self, *args): + """Release lock and close file""" + if self.filep: + try: + fcntl.flock(self.filep.fileno(), fcntl.LOCK_UN) + finally: + self.filep.close() + self.filep = None + + def read(filepath: str, config: Optional[Config] = None) -> str: """ Read entire file while holding a shared lock. diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 670ccc1..2ee85fa 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -120,12 +120,20 @@ class OONISubmitter: """ Process CSV file, submitting measurements for complete windows. + This method is safe for concurrent access from multiple processes + using file-based locking. Only one process will process a given + CSV file at a time. The CSV file will be read atomically. + Args: csv_path: Path to CSV file to process The submitter maintains state about the last processed window and will continue from there on subsequent runs. """ + with lockedfile.Mutex(f"{csv_path}.lock"): + self._process_csv_locked(csv_path) + + def _process_csv_locked(self, csv_path: str) -> None: csv_file = FieldTestingCSVFile(csv_path) # Start from last submission time or the default project start -- GitLab From c860f77568c409012c9387759df446b9825f468c Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Thu, 12 Dec 2024 15:16:03 +0100 Subject: [PATCH 14/75] fix(oonisubmitter): improve handling of state file 1. make sure we emit an exception if we cannot read the state file because, e.g., is corrupted or because of I/O error 2. store the timestamp using a more compact format --- oonisubmitter/lockedfile.py | 4 ++-- oonisubmitter/policy.py | 5 ++--- oonisubmitter/submitter.py | 33 +++++++++++++++++++++------------ 3 files changed, 25 insertions(+), 17 deletions(-) diff --git a/oonisubmitter/lockedfile.py b/oonisubmitter/lockedfile.py index 3f46823..66b06c9 100644 --- a/oonisubmitter/lockedfile.py +++ b/oonisubmitter/lockedfile.py @@ -51,11 +51,11 @@ class Mutex: self.filepath = filepath self.filep = None - def __enter__(self) -> 'Mutex': + def __enter__(self) -> "Mutex": """Acquire exclusive lock using flock""" try: # Open for read-write, create if doesn't exist - self.filep = open(self.filepath, 'a+') + self.filep = open(self.filepath, "a+") fcntl.flock(self.filep.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) return self except (IOError, BlockingIOError) as err: diff --git a/oonisubmitter/policy.py b/oonisubmitter/policy.py index 6e20bb2..d78544b 100644 --- a/oonisubmitter/policy.py +++ b/oonisubmitter/policy.py @@ -63,6 +63,7 @@ class DailyWindowPolicy(WindowPolicy): days=1 ) + class HourlyWindowPolicy(WindowPolicy): """ A WindowPolicy organizing measurements in hourly buckets. @@ -72,6 +73,4 @@ class HourlyWindowPolicy(WindowPolicy): """ def _compute_window_end(self, start: datetime) -> datetime: - return start.replace( - minute=0, second=0, microsecond=0 - ) + timedelta(hours=1) + return start.replace(minute=0, second=0, microsecond=0) + timedelta(hours=1) diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 2ee85fa..8d6c586 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -31,38 +31,45 @@ import json # TODO(bassosimone): I think we should consider creating # proper python modules eventually :thinking: from aggregator import AggregatorConfig, EndpointAggregator -from model import FieldTestingCSVFile, OONIMeasurement +from model import FieldTestingCSVFile, OONIMeasurement, datetime_to_compact_utc from policy import WindowPolicy, PROJECT_START from serializer import OONISerializer import lockedfile +class StateFileError(Exception): + """Raised when there are problems with the state file.""" + + @dataclass class SubmitterState: """Persistent state of the submitter.""" next_submission_after: Optional[datetime] = None - # TODO(bassosimone): consider using a more portable data format - # for storing the datetime, e.g., `20060102T150405Z`. - @classmethod def from_dict(cls, data: dict) -> "SubmitterState": """Create from dictionary.""" - return cls( - next_submission_after=( - datetime.fromisoformat(data["next_submission_after"]) - if data.get("next_submission_after") - else None + if not data.get("next_submission_after"): + return cls(next_submission_after=None) + + # Parse compact UTC format: YYYYMMDDThhmmssZ + try: + dt = datetime.strptime(data["next_submission_after"], "%Y%m%dT%H%M%SZ") + return cls(next_submission_after=dt.replace(tzinfo=timezone.utc)) + except ValueError as err: + raise StateFileError( + f"corrupt state file: invalid datetime format: {err}. " + "This error requires manual intervention to avoid " + "potential duplicate submissions." ) - ) def to_dict(self) -> dict: """Convert to dictionary.""" return { "next_submission_after": ( - self.next_submission_after.isoformat() + datetime_to_compact_utc(self.next_submission_after) if self.next_submission_after else None ) @@ -99,8 +106,10 @@ class OONISubmitter: try: data = json.loads(lockedfile.read(self.state_file)) return SubmitterState.from_dict(data) - except (FileNotFoundError, json.JSONDecodeError, IOError): + except FileNotFoundError: return SubmitterState() + except (json.JSONDecodeError, IOError) as err: + raise StateFileError(f"cannot load state file: {err}") def _save_state(self): """ -- GitLab From d3a15a4e4a473313313b74ccd51c7412ccf5ed51 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Thu, 12 Dec 2024 15:29:24 +0100 Subject: [PATCH 15/75] fix(oonisubmitter): further clarity in terms of error types --- oonisubmitter/submitter.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 8d6c586..72fbdb7 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -108,8 +108,14 @@ class OONISubmitter: return SubmitterState.from_dict(data) except FileNotFoundError: return SubmitterState() - except (json.JSONDecodeError, IOError) as err: - raise StateFileError(f"cannot load state file: {err}") + except json.JSONDecodeError as err: + raise StateFileError(f"corrupt state file: {err}") + except IOError as err: + # Don't risk duplicates if file exists but can't be read + raise IOError( + f"cannot access state file: {err}. Manual intervention required " + "to avoid potential duplicate submissions." + ) def _save_state(self): """ -- GitLab From 4b5b1734e2db4a920e549c5da3f1d3b89a27dece Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Thu, 12 Dec 2024 17:34:49 +0100 Subject: [PATCH 16/75] feat(oonisubmitter): submit via OONI API This set of changes probably concludes the effort of implementing the core functionality needed to submit aggregate tunnel metrics. From now on, it's a matter of refining, writing tests, etc. --- oonisubmitter/aggregator.py | 3 + oonisubmitter/model.py | 2 + oonisubmitter/ooniapi.py | 154 ++++++++++++++++++++++++++++++++++++ oonisubmitter/serializer.py | 2 + oonisubmitter/submitter.py | 20 ++++- 5 files changed, 178 insertions(+), 3 deletions(-) create mode 100644 oonisubmitter/ooniapi.py diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index 457bef9..bdaa7c1 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -37,6 +37,9 @@ class AggregatorConfig: # rounding sample_size to the nearest round_to round_to: int = 100 + software_name: str = "leap/aggregate-tunnel-metrics" + software_version: str = "0.1.0" + @dataclass class AggregateEndpointState: diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 8ddbfc0..b4f818f 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -297,6 +297,8 @@ class OONIMeasurement: measurement_start_time: datetime probe_asn: str # Format: ^AS[0-9]+$ probe_cc: str # Format: ^[A-Z]{2}$ + software_name: str + software_version: str test_keys: AggregateTunnelMetricsTestKeys test_name: str test_runtime: float diff --git a/oonisubmitter/ooniapi.py b/oonisubmitter/ooniapi.py new file mode 100644 index 0000000..4a9a097 --- /dev/null +++ b/oonisubmitter/ooniapi.py @@ -0,0 +1,154 @@ +""" +OONI API client for submitting measurements to collectors. + +Implements the OONI collector protocol v3.0.0. + +See: https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md +""" + +from dataclasses import dataclass +from typing import Optional + +import json +import urllib.request +import urllib.error + +import model + + +class OONIAPIError(Exception): + """Raised when there are OONI API errors.""" + + +@dataclass +class Config: + """Configuration for OONI API client""" + + collector_base_url: str # e.g., "https://api.ooni.io/" + timeout: float = 30.0 + + +class Client: + """ + Client for interacting with OONI collectors implementing v3.0.0 + of the specification. + """ + + # TODO(bassosimone): consider support for retries. The spec + # says the following about retries: + # + # > A client side implementation MAY retry any failing collector + # > operation immediately for three times in case there is a + # > DNS or TCP error. [...] If all these immediate retries fail, + # > then the client SHOULD arrange for resubmitting the + # > measurement at a later time + + def __init__(self, config: Config): + self.config = config + + def create_report_from_measurement( + self, + measurement: model.OONIMeasurement, + ) -> str: + """Convenience method to create report from existing measurement.""" + return self.create_report( + test_name=measurement.test_name, + test_version=measurement.test_version, + software_name=measurement.software_name, + software_version=measurement.software_version, + probe_asn=measurement.probe_asn, + probe_cc=measurement.probe_cc, + ) + + def create_report( + self, + test_name: str, + test_version: str, + software_name: str, + software_version: str, + probe_asn: str, + probe_cc: str, + ) -> str: + """ + Create a new report. + + Returns: + str: Report ID to use for submitting measurements + + Raises: + OONIAPIError: If submission fails + URLError: For network/DNS issues + """ + report = { + "data_format_version": "0.2.0", + "format": "json", + "probe_asn": probe_asn, + "probe_cc": probe_cc, + "software_name": software_name, + "software_version": software_version, + "test_name": test_name, + "test_version": test_version, + } + + data = json.dumps(report).encode("utf-8") + req = urllib.request.Request( + f"{self.config.collector_base_url}/report", + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: + if resp.status != 200: + raise OONIAPIError(f"unexpected status: {resp.status}") + response = json.loads(resp.read().decode()) + if "report_id" not in response: + raise OONIAPIError("missing report_id in response") + return response["report_id"] + + except urllib.error.HTTPError as err: + raise OONIAPIError(f"HTTP error: {err}") + + def update_report( + self, + report_id: str, + measurement: model.OONIMeasurement, + ) -> Optional[str]: + """ + Update a report by adding a measurement. + + Args: + report_id: The ID returned by create_report() + measurement: The measurement to submit + + Returns: + Optional[str]: measurement_id if provided by server + + Raises: + OONIAPIError: If submission fails + URLError: For network/DNS issues + """ + data = json.dumps({"format": "json", "content": measurement.as_dict()}).encode( + "utf-8" + ) + + req = urllib.request.Request( + f"{self.config.collector_base_url}/report/{report_id}", + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: + if resp.status != 200: + raise OONIAPIError(f"unexpected status: {resp.status}") + try: + response = json.loads(resp.read().decode()) + return response.get("measurement_id") + except json.JSONDecodeError: + return None + + except urllib.error.HTTPError as err: + raise OONIAPIError(f"HTTP error: {err}") diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index b4ebe26..44e2f24 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -248,6 +248,8 @@ class OONISerializer: measurement_start_time=measurement_time, probe_asn=self.config.probe_asn, probe_cc=self.config.probe_cc, + software_name=self.config.software_name, + software_version=self.config.software_version, test_keys=test_keys, test_name="aggregate_tunnel_metrics", test_runtime=0.0, diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 72fbdb7..976de57 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -13,6 +13,7 @@ tunnel metrics to OONI. It handles: Example usage: submitter = OONISubmitter( + client=ooniapi.Client(...), config=AggregatorConfig(...), state_file="/var/lib/oonisubmitter/state.json", window_policy=WeeklyWindowPolicy() @@ -36,6 +37,7 @@ from policy import WindowPolicy, PROJECT_START from serializer import OONISerializer import lockedfile +import ooniapi class StateFileError(Exception): @@ -89,15 +91,18 @@ class OONISubmitter: def __init__( self, + client: ooniapi.Client, config: AggregatorConfig, state_file: str, window_policy: WindowPolicy, ): + self.client = client self.config = config self.state_file = state_file self.serializer = OONISerializer(config) self.window_policy = window_policy self.state = self._load_state() + self._cur_report_id: Optional[str] = None def _load_state(self) -> SubmitterState: """ @@ -127,9 +132,18 @@ class OONISubmitter: """ Submit measurement to OONI. """ - # TODO(bassosimone): implement the actual submission logic. - # For now we just print what we would do for debugging. - print(f"Would submit: {measurement.as_dict()}") + # Create a report ID on first usage. + # + # TODO(bassosimone): should we generate a new report ID for + # each submitted measurement? Maybe this would simplify indexing + # via the OONI API, but I am not sure about this. + if not self._cur_report_id: + self._cur_report_id = self.client.create_report_from_measurement( + measurement + ) + + # Update report and submit the measurement. + self.client.update_report(self._cur_report_id, measurement) def process_csv(self, csv_path: str) -> None: """ -- GitLab From 8bf657d4adea2f5f87efacef209e0f7be5d5b59c Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 15:34:52 +0100 Subject: [PATCH 17/75] doc(oonisubmitter): update DESIGN w/ @powerpuffin comments This commit updates the interim DESIGN file to take into account the comments and feedback provided by @powerpuffin. --- oonisubmitter/DESIGN.md | 57 ++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/oonisubmitter/DESIGN.md b/oonisubmitter/DESIGN.md index cb3ad6f..434f47d 100644 --- a/oonisubmitter/DESIGN.md +++ b/oonisubmitter/DESIGN.md @@ -30,31 +30,35 @@ IIUC, the existing ETL pipeline has this structure: ```ascii [field tests] -> [logs/*.txt] -> [parse_all_logs.py (batch)] -> [metrics_all_logs.csv] -> [dashboard.py (dedup)] - ^ - | - [file-monitor (fastpath)] -> [prometheus metrics] + +[field tests] -> [logs/*.txt] -> [parse_log.py (fastpath)] -> [metrics_all_logs.csv] -> [prometheus metrics] ``` Key characteristics: 1. Using columnar store (CSV). -2. `parse_all_logs.py` is the batch/reprocessing component. +2. `parse_all_logs.py` is the batch/reprocessing component, which only +runs when the `dashboard` container is (re)started. -3. `file-monitor` is the fastpath/streaming component. +3. `file-monitor` is the fastpath/streaming component, which only runs +when new field testing logs are submitted. 4. Two current outputs: - Dashboard visualization (via pandas); - Grafana dashboard (via pushgateway and prometheus metrics). -5. Deduplication happens late (`dashboard.py`). +5. Deduplication happens late (`dashboard.py`), and *does not* seem +to happen in the fastpath/streaming component. 6. `docker-compose` used to put Python scripts together. -7. No CSV file locking, unless I am missing some detail here. +7. No CSV file locking, which may be a problem if we receive a +new log entry while the batch component is running. -8. Given the focus on reprocessing I assume the intent is to keep -the logfiles forever and preriodically reprocess them. +8. It is unclear whether the intent is to keep the logfiles +forever and reprocess them: given the small number of logs +available currently, reprocessing seemds an OK option. 9. I assume, given the current project structure, that scaling will be implemented by partitioning the CSV file across dates, or that we @@ -66,23 +70,21 @@ columnar format seems a good fit for this kind of data. ## Topics to Discuss 1. I am adding 3-4 new files to the pipeline and I am following -the existing coding style of not using `__init__.py` but I think -we should discuss whether the amount of code I am adding could -suggest we start adding some `__init__.py` and modules. +the existing coding style of not using `__init__.py` and it seems +there is rough consensus around doing this. -2. IIUC, there is no file locking protecting the columnar data -store, so it is possible that multiple readers and writers could -race to update the datastore itself. I wonder whether it would -be a good idea to consider adding file locking. +2. there is no file locking protecting the columnar data store, +so it is possible that multiple readers and writers could race to +update the datastore itself. To address, this, there's rough +consensus around adding file locking. 3. I tried to design the OONI submitter as an independent component that could work also in case of time based partitioning of the CSV files, under the assumption that this would be the future expected growth trajectory. -4. To make OONI export and submission more robust, I suppose we -should see to move deduplication in the batch processing, rather -than inside the dashboard support script -- sounds good? +4. To make OONI export and submission more robust, we should +run deduplication *before* running the OONI exporter. 5. I think we need to introduce OONI-export-specific state that remembers the last processed log entry, to avoid considering them @@ -121,7 +123,8 @@ should work, but this seems a few months away from now. 1. Deduplication - Move to parse_all_logs.py - - Implement at CSV write time + - Implement at CSV write time (taking into account that also + the streaming component may need deduplication) 2. Scaling - Date-based CSV partitioning @@ -165,6 +168,9 @@ should work, but this seems a few months away from now. | endpoint_cc | Tunnel endpoint CC | (needs resolution) | | ping_target_address | Ping test target IP | ping command target | +Based on feedback, it seems we *should not* submit the OBFS4 server address +and port, because we want to keep them private. + ### CSV Field Mapping Notes @@ -206,9 +212,14 @@ the rest of the ETL pipeline first. 1. Is my understanding of the pipeline correct? 2. Is manual submission acceptable for initial implementation? -3. Are the proposed CSV-fields extensions acceptable? +3. Are the proposed CSV-fields extensions acceptable? It seems we *should +not* submit the OBFS4 server address and port, so we need to modify the +current implementation to make this possible. *Future considerations*: -4. Should we consider implementing file locking? -5. What do you think of moving the deduplication to the batch processing? +4. Should we consider implementing file locking? Yes! +5. What do you think of moving the deduplication to the batch processing? Yes, but +there are corner cases we need to consider. For example, duplicate submission of +the same content (with same starting date) but, for some reason, some extra content +being presend in the second field-testing logs submission. -- GitLab From 7908c5f47a0bd733f89ecd3daac0e7bab4c9775e Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 15:41:05 +0100 Subject: [PATCH 18/75] refactor(oonisubmitter): add __init__.py This commit transforms `oonisubmitter` into a module by adding the `__init__.py` file to the directory. --- oonisubmitter/__init__.py | 8 ++++++++ oonisubmitter/aggregator.py | 6 ++---- oonisubmitter/ooniapi.py | 2 +- oonisubmitter/serializer.py | 8 +++----- oonisubmitter/submitter.py | 12 ++++++------ 5 files changed, 20 insertions(+), 16 deletions(-) create mode 100644 oonisubmitter/__init__.py diff --git a/oonisubmitter/__init__.py b/oonisubmitter/__init__.py new file mode 100644 index 0000000..b1890b8 --- /dev/null +++ b/oonisubmitter/__init__.py @@ -0,0 +1,8 @@ +""" +Support for submitting aggregate-tunnel metrics as OONI +measurements using the OONI collector API. +""" + +from submitter import OONISubmitter + +__all__ = ["OONISubmitter"] diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index bdaa7c1..36adaf6 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -9,10 +9,8 @@ from dataclasses import dataclass, field from datetime import datetime from typing import Dict, List -# TODO(bassosimone): I think we should consider creating -# proper python modules eventually :thinking: -from identifiers import EndpointID, NDTServerID -import model +from .identifiers import EndpointID, NDTServerID +from . import model @dataclass diff --git a/oonisubmitter/ooniapi.py b/oonisubmitter/ooniapi.py index 4a9a097..f24f74e 100644 --- a/oonisubmitter/ooniapi.py +++ b/oonisubmitter/ooniapi.py @@ -13,7 +13,7 @@ import json import urllib.request import urllib.error -import model +from . import model class OONIAPIError(Exception): diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index 44e2f24..ecf7c6a 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -9,11 +9,9 @@ from typing import Any, Dict, List, Optional from statistics import quantiles from urllib.parse import urlunparse, urlencode -# TODO(bassosimone): I think we should consider creating -# proper python modules eventually :thinking: -from identifiers import NDTServerID -from aggregator import AggregateEndpointState, AggregatorConfig -import model +from .identifiers import NDTServerID +from .aggregator import AggregateEndpointState, AggregatorConfig +from . import model class SerializationConfigError(Exception): diff --git a/oonisubmitter/submitter.py b/oonisubmitter/submitter.py index 976de57..b3679fc 100644 --- a/oonisubmitter/submitter.py +++ b/oonisubmitter/submitter.py @@ -31,13 +31,13 @@ import json # TODO(bassosimone): I think we should consider creating # proper python modules eventually :thinking: -from aggregator import AggregatorConfig, EndpointAggregator -from model import FieldTestingCSVFile, OONIMeasurement, datetime_to_compact_utc -from policy import WindowPolicy, PROJECT_START -from serializer import OONISerializer +from .aggregator import AggregatorConfig, EndpointAggregator +from .model import FieldTestingCSVFile, OONIMeasurement, datetime_to_compact_utc +from .policy import WindowPolicy, PROJECT_START +from .serializer import OONISerializer -import lockedfile -import ooniapi +from . import lockedfile +from . import ooniapi class StateFileError(Exception): -- GitLab From 3c3448c7db1159a9dcf5e40673d84f53dd885b1c Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 16:09:26 +0100 Subject: [PATCH 19/75] fix(oonisubmitter): make tests pass --- oonisubmitter/__init__.py | 2 +- oonisubmitter/aggregator.py | 4 ++-- test/__init__.py | 0 test/oonisubmitter/__init__.py | 0 .../oonisubmitter/test_aggregator.py | 21 ++++++++++++------- .../testdata/expected_state.json | 4 ++-- .../oonisubmitter}/testdata/sample.csv | 0 7 files changed, 19 insertions(+), 12 deletions(-) create mode 100644 test/__init__.py create mode 100644 test/oonisubmitter/__init__.py rename oonisubmitter/aggregator_test.py => test/oonisubmitter/test_aggregator.py (78%) rename {oonisubmitter => test/oonisubmitter}/testdata/expected_state.json (97%) rename {oonisubmitter => test/oonisubmitter}/testdata/sample.csv (100%) diff --git a/oonisubmitter/__init__.py b/oonisubmitter/__init__.py index b1890b8..d89ad7a 100644 --- a/oonisubmitter/__init__.py +++ b/oonisubmitter/__init__.py @@ -3,6 +3,6 @@ Support for submitting aggregate-tunnel metrics as OONI measurements using the OONI collector API. """ -from submitter import OONISubmitter +from .submitter import OONISubmitter __all__ = ["OONISubmitter"] diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index 36adaf6..72f4d6c 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -188,7 +188,7 @@ class AggregateEndpointState: "provider": self.provider, "probe_asn": self.probe_asn, "probe_cc": self.probe_cc, - "scope": self.scope, + "scope": self.scope.value, # Time window "window_start": model.datetime_to_compact_utc(self.window_start), "window_end": model.datetime_to_compact_utc(self.window_end), @@ -264,7 +264,7 @@ class EndpointAggregator: "upstream_collector": self.config.upstream_collector, "probe_asn": self.config.probe_asn, "probe_cc": self.config.probe_cc, - "scope": self.config.scope, + "scope": self.config.scope.value, }, # Time window "window_start": model.datetime_to_compact_utc(self.window_start), diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/oonisubmitter/__init__.py b/test/oonisubmitter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/oonisubmitter/aggregator_test.py b/test/oonisubmitter/test_aggregator.py similarity index 78% rename from oonisubmitter/aggregator_test.py rename to test/oonisubmitter/test_aggregator.py index d4a3498..97207f1 100755 --- a/oonisubmitter/aggregator_test.py +++ b/test/oonisubmitter/test_aggregator.py @@ -2,13 +2,18 @@ """Basic tests for the endpoint aggregator""" -import json -import unittest from datetime import datetime, timezone from difflib import unified_diff -import aggregator -import model +import json +import os +import unittest + +from oonisubmitter import aggregator +from oonisubmitter import model + + +_TESTROOT = os.path.dirname(os.path.abspath(__file__)) class TestAggregator(unittest.TestCase): @@ -17,7 +22,9 @@ class TestAggregator(unittest.TestCase): def test_aggregator(self): """Verify that the aggregator correctly processes sample data""" # Load test data - csv_file = model.FieldTestingCSVFile("testdata/sample.csv") + csv_file = model.FieldTestingCSVFile( + os.path.join(_TESTROOT, "testdata", "sample.csv") + ) # Configure aggregator config = aggregator.AggregatorConfig( @@ -36,8 +43,8 @@ class TestAggregator(unittest.TestCase): # Compare with expected state actual_lines = json.dumps(aggr.to_dict(), indent=4).splitlines() - with open("testdata/expected_state.json") as f: - expected_lines = f.read().splitlines() + with open(os.path.join(_TESTROOT, "testdata", "expected_state.json")) as filep: + expected_lines = filep.read().splitlines() if actual_lines != expected_lines: diff = list( diff --git a/oonisubmitter/testdata/expected_state.json b/test/oonisubmitter/testdata/expected_state.json similarity index 97% rename from oonisubmitter/testdata/expected_state.json rename to test/oonisubmitter/testdata/expected_state.json index 30bebf7..f95493c 100644 --- a/oonisubmitter/testdata/expected_state.json +++ b/test/oonisubmitter/testdata/expected_state.json @@ -43,7 +43,7 @@ } }, "ndt_measurements": { - "ndt1.test.org:10.0.0.1": { + "ndt1.test.org|10.0.0.1": { "download_throughput": [ 100.0 ], @@ -98,7 +98,7 @@ } }, "ndt_measurements": { - "ndt1.test.org:10.0.0.1": { + "ndt1.test.org|10.0.0.1": { "download_throughput": [ 120.0 ], diff --git a/oonisubmitter/testdata/sample.csv b/test/oonisubmitter/testdata/sample.csv similarity index 100% rename from oonisubmitter/testdata/sample.csv rename to test/oonisubmitter/testdata/sample.csv -- GitLab From 61dfaee7bc30f6d6cf79326b4f0f7d0bed6e5bc7 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 17:57:58 +0100 Subject: [PATCH 20/75] feat(oonisubmitter): add tests for identifier.py --- oonisubmitter/aggregator.py | 4 +- oonisubmitter/identifiers.py | 46 ++++++++-- test/oonisubmitter/test_identifier.py | 122 ++++++++++++++++++++++++++ 3 files changed, 161 insertions(+), 11 deletions(-) create mode 100644 test/oonisubmitter/test_identifier.py diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index 72f4d6c..c6344b1 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -9,7 +9,7 @@ from dataclasses import dataclass, field from datetime import datetime from typing import Dict, List -from .identifiers import EndpointID, NDTServerID +from .identifiers import BridgeEndpointID, NDTServerID from . import model @@ -215,7 +215,7 @@ class EndpointAggregator: def _make_key(self, entry: model.FieldTestingCSVEntry) -> str: """Create unique key for an endpoint""" return str( - EndpointID( + BridgeEndpointID( hostname=entry.endpoint_hostname, address=entry.endpoint_address, port=entry.endpoint_port, diff --git a/oonisubmitter/identifiers.py b/oonisubmitter/identifiers.py index 5149c22..32120b7 100644 --- a/oonisubmitter/identifiers.py +++ b/oonisubmitter/identifiers.py @@ -1,4 +1,6 @@ -"""Common identifiers used across the pipeline.""" +""" +Common identifiers used across the pipeline. +""" from dataclasses import dataclass @@ -7,39 +9,65 @@ SEPARATOR = "|" @dataclass class NDTServerID: - """NDT server identifier.""" + """ + NDT server identifier. + """ + # The hostname used by an NDT server (e.g., `ndt.example.com`). hostname: str + + # The NDT server's IP address (e.g., `1.2.3.4`). address: str def __str__(self) -> str: - """String representation for dictionary keys.""" + """ + String representation of the NDT server identifier. + """ return f"{self.hostname}{SEPARATOR}{self.address}" @classmethod def parse(cls, identifier: str) -> "NDTServerID": - """Parse identifier string into a NDTServerID.""" + """ + Parse identifier string into a NDTServerID. + """ hostname, address = identifier.split(SEPARATOR, 1) + if hostname == "" or address == "": + raise ValueError("Empty fields in identifier") return cls(hostname=hostname, address=address) @dataclass -class EndpointID: - """Endpoint identifier.""" +class BridgeEndpointID: + """ + Bridge endpoint identifier. + """ + # hostname used by the bridge endpoint (e.g., `bridge.example.com`). hostname: str + + # IP address of the bridge endpoint (e.g., `1.2.3.4`) address: str + + # Port number used by the bridge endpoint (e.g., `443`). port: int + + # Protocol used by the bridge endpoint (e.g., `obfs4+kcp`). protocol: str def __str__(self) -> str: - """String representation for dictionary keys.""" + """ + String representation for dictionary keys. + """ return f"{self.hostname}{SEPARATOR}{self.address}{SEPARATOR}{self.port}{SEPARATOR}{self.protocol}" @classmethod - def parse(cls, identifier: str) -> "EndpointID": - """Parse identifier string into an Endpoint.""" + def parse(cls, identifier: str) -> "BridgeEndpointID": + """ + Parse identifier string into an Endpoint. + """ hostname, address, port, protocol = identifier.split(SEPARATOR, 3) + if hostname == "" or address == "" or port == "" or protocol == "": + raise ValueError("Empty fields in identifier") return cls( hostname=hostname, address=address, port=int(port), protocol=protocol ) diff --git a/test/oonisubmitter/test_identifier.py b/test/oonisubmitter/test_identifier.py new file mode 100644 index 0000000..135131f --- /dev/null +++ b/test/oonisubmitter/test_identifier.py @@ -0,0 +1,122 @@ +"""Tests for identifier classes used in the pipeline.""" + +import unittest + +from oonisubmitter.identifiers import ( + SEPARATOR, + NDTServerID, + BridgeEndpointID, +) + + +class TestNDTServerID(unittest.TestCase): + """Test suite for NDTServerID class.""" + + def test_create_valid_server(self): + """Test creating a valid NDTServerID.""" + server = NDTServerID(hostname="ndt.example.com", address="1.2.3.4") + self.assertEqual(server.hostname, "ndt.example.com") + self.assertEqual(server.address, "1.2.3.4") + + def test_string_representation(self): + """Test string serialization of NDTServerID.""" + server = NDTServerID(hostname="ndt.example.com", address="1.2.3.4") + expected = f"ndt.example.com{SEPARATOR}1.2.3.4" + self.assertEqual(str(server), expected) + + def test_parse_valid_string(self): + """Test parsing a valid string into NDTServerID.""" + identifier = f"ndt.example.com{SEPARATOR}1.2.3.4" + server = NDTServerID.parse(identifier) + self.assertEqual(server.hostname, "ndt.example.com") + self.assertEqual(server.address, "1.2.3.4") + + def test_roundtrip_conversion(self): + """Test server->string->server conversion preserves data.""" + original = NDTServerID(hostname="ndt.example.com", address="1.2.3.4") + serialized = str(original) + parsed = NDTServerID.parse(serialized) + self.assertEqual(parsed, original) + + def test_parse_invalid_format(self): + """Test parsing string with wrong format raises ValueError.""" + with self.assertRaises(ValueError): + NDTServerID.parse("invalid") + + def test_parse_empty_fields(self): + """Test parsing string with empty fields.""" + identifier = f"{SEPARATOR}" + with self.assertRaises(ValueError): + NDTServerID.parse(identifier) + + +class TestBridgeEndpointID(unittest.TestCase): + """Test suite for BridgeEndpointID class.""" + + def test_create_valid_endpoint(self): + """Test creating a valid BridgeEndpointID.""" + endpoint = BridgeEndpointID( + hostname="bridge.example.com", + address="1.2.3.4", + port=443, + protocol="obfs4", + ) + self.assertEqual(endpoint.hostname, "bridge.example.com") + self.assertEqual(endpoint.address, "1.2.3.4") + self.assertEqual(endpoint.port, 443) + self.assertEqual(endpoint.protocol, "obfs4") + + def test_string_representation(self): + """Test string serialization of BridgeEndpointID.""" + endpoint = BridgeEndpointID( + hostname="bridge.example.com", + address="1.2.3.4", + port=443, + protocol="obfs4", + ) + expected = f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}443{SEPARATOR}obfs4" + self.assertEqual(str(endpoint), expected) + + def test_parse_valid_string(self): + """Test parsing a valid string into BridgeEndpointID.""" + identifier = f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}443{SEPARATOR}obfs4" + endpoint = BridgeEndpointID.parse(identifier) + self.assertEqual(endpoint.hostname, "bridge.example.com") + self.assertEqual(endpoint.address, "1.2.3.4") + self.assertEqual(endpoint.port, 443) + self.assertEqual(endpoint.protocol, "obfs4") + + def test_roundtrip_conversion(self): + """Test endpoint->string->endpoint conversion preserves data.""" + original = BridgeEndpointID( + hostname="bridge.example.com", + address="1.2.3.4", + port=443, + protocol="obfs4", + ) + serialized = str(original) + parsed = BridgeEndpointID.parse(serialized) + self.assertEqual(parsed, original) + + def test_parse_invalid_format(self): + """Test parsing string with wrong format raises ValueError.""" + with self.assertRaises(ValueError): + BridgeEndpointID.parse("invalid") + + def test_parse_invalid_port(self): + """Test parsing string with non-integer port raises ValueError.""" + identifier = f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}invalid{SEPARATOR}obfs4" + with self.assertRaises(ValueError): + BridgeEndpointID.parse(identifier) + + def test_parse_missing_fields(self): + """Test parsing string with missing fields raises ValueError.""" + identifier = f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}443" + with self.assertRaises(ValueError): + BridgeEndpointID.parse(identifier) + + def test_parse_empty_fields(self): + """Test parsing string with empty fields.""" + identifier = f"{SEPARATOR}{SEPARATOR}{SEPARATOR}" + with self.assertRaises(ValueError): + BridgeEndpointID.parse(identifier) -- GitLab From 983edd1cc603536ef48aee545ff9d447e0fc43ba Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 17:58:16 +0100 Subject: [PATCH 21/75] fix: use conventional name `tests` for tests Using `test` is less common. --- {test => tests}/__init__.py | 0 {test => tests}/oonisubmitter/__init__.py | 0 {test => tests}/oonisubmitter/test_aggregator.py | 0 {test => tests}/oonisubmitter/test_identifier.py | 0 {test => tests}/oonisubmitter/testdata/expected_state.json | 0 {test => tests}/oonisubmitter/testdata/sample.csv | 0 6 files changed, 0 insertions(+), 0 deletions(-) rename {test => tests}/__init__.py (100%) rename {test => tests}/oonisubmitter/__init__.py (100%) rename {test => tests}/oonisubmitter/test_aggregator.py (100%) rename {test => tests}/oonisubmitter/test_identifier.py (100%) rename {test => tests}/oonisubmitter/testdata/expected_state.json (100%) rename {test => tests}/oonisubmitter/testdata/sample.csv (100%) diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/oonisubmitter/__init__.py b/tests/oonisubmitter/__init__.py similarity index 100% rename from test/oonisubmitter/__init__.py rename to tests/oonisubmitter/__init__.py diff --git a/test/oonisubmitter/test_aggregator.py b/tests/oonisubmitter/test_aggregator.py similarity index 100% rename from test/oonisubmitter/test_aggregator.py rename to tests/oonisubmitter/test_aggregator.py diff --git a/test/oonisubmitter/test_identifier.py b/tests/oonisubmitter/test_identifier.py similarity index 100% rename from test/oonisubmitter/test_identifier.py rename to tests/oonisubmitter/test_identifier.py diff --git a/test/oonisubmitter/testdata/expected_state.json b/tests/oonisubmitter/testdata/expected_state.json similarity index 100% rename from test/oonisubmitter/testdata/expected_state.json rename to tests/oonisubmitter/testdata/expected_state.json diff --git a/test/oonisubmitter/testdata/sample.csv b/tests/oonisubmitter/testdata/sample.csv similarity index 100% rename from test/oonisubmitter/testdata/sample.csv rename to tests/oonisubmitter/testdata/sample.csv -- GitLab From 28b0a2c1762a5ae6f50b5059ac0d17e352d91ecc Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 17:59:44 +0100 Subject: [PATCH 22/75] fix(tests): rename and remove the x bit --- tests/oonisubmitter/test_aggregator.py | 0 tests/oonisubmitter/{test_identifier.py => test_identifiers.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) mode change 100755 => 100644 tests/oonisubmitter/test_aggregator.py rename tests/oonisubmitter/{test_identifier.py => test_identifiers.py} (100%) diff --git a/tests/oonisubmitter/test_aggregator.py b/tests/oonisubmitter/test_aggregator.py old mode 100755 new mode 100644 diff --git a/tests/oonisubmitter/test_identifier.py b/tests/oonisubmitter/test_identifiers.py similarity index 100% rename from tests/oonisubmitter/test_identifier.py rename to tests/oonisubmitter/test_identifiers.py -- GitLab From 04913c745fd7fe3d4a002e12c5b298bb08f69295 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 17 Dec 2024 22:01:49 +0100 Subject: [PATCH 23/75] refactor(oonireport): independent module implementing submission This commit extracts code to submit from oonisubmitter. I want to have independent modules. While there improve tests and add a __main__.py. --- oonireport/__init__.py | 19 +++ oonireport/__main__.py | 87 +++++++++++++ oonireport/collector.py | 154 ++++++++++++++++++++++++ oonireport/load.py | 123 +++++++++++++++++++ oonireport/model.py | 101 ++++++++++++++++ tests/oonireport/__init__.py | 1 + tests/oonireport/test_collector.py | 142 ++++++++++++++++++++++ tests/oonireport/test_load.py | 39 ++++++ tests/oonireport/test_model.py | 63 ++++++++++ tests/oonisubmitter/test_identifiers.py | 8 +- 10 files changed, 735 insertions(+), 2 deletions(-) create mode 100644 oonireport/__init__.py create mode 100644 oonireport/__main__.py create mode 100644 oonireport/collector.py create mode 100644 oonireport/load.py create mode 100644 oonireport/model.py create mode 100644 tests/oonireport/__init__.py create mode 100644 tests/oonireport/test_collector.py create mode 100644 tests/oonireport/test_load.py create mode 100644 tests/oonireport/test_model.py diff --git a/oonireport/__init__.py b/oonireport/__init__.py new file mode 100644 index 0000000..4031900 --- /dev/null +++ b/oonireport/__init__.py @@ -0,0 +1,19 @@ +""" +Code to submit OONI reports to the OONI collector. + +Exposes a subset of the original `oonireport(1)` CLI tool +usable via `python3 -m oonireport [args...]`. +""" + +from .collector import CollectorClient, CollectorConfig +from .load import load_measurements +from .model import APIError, Measurement, TestKeys + +__all__ = [ + "APIError", + "CollectorClient", + "CollectorConfig", + "Measurement", + "TestKeys", + "load_measurements", +] diff --git a/oonireport/__main__.py b/oonireport/__main__.py new file mode 100644 index 0000000..6162e00 --- /dev/null +++ b/oonireport/__main__.py @@ -0,0 +1,87 @@ +""" +Command line interface for submitting OONI measurements +emulating a subset of the `oonireport(1)` tool. +""" + +from typing import List, Optional + +import argparse +import sys + +from . import CollectorClient, CollectorConfig, load_measurements + + +def main(args: Optional[List[str]] = None) -> int: + """Main function implementing the `oonireport(1)` functionality.""" + parser = argparse.ArgumentParser( + description="Submit OONI measurements to a collector" + ) + subparsers = parser.add_subparsers(dest="command", required=True) + upload = subparsers.add_parser("upload", help="upload measurements to a collector") + + upload.add_argument( + "-f", + "--file", + required=True, + help="measurement file to submit", + ) + upload.add_argument( + "-c", + "--collector", + default="https://api.ooni.io/", + help="collector base URL (default: https://api.ooni.io/)", + ) + upload.add_argument( + "-t", + "--timeout", + type=float, + default=30.0, + help="timeout in seconds (default: 30.0)", + ) + + opts = parser.parse_args(args) + if opts.command != "upload": + print("Unknown command", file=sys.stderr) + return 1 + + try: + measurements = load_measurements(opts.file) + except Exception as exc: + print(f"Failed to load measurements: {exc}", file=sys.stderr) + return 1 + + if not measurements: + print("No measurements to submit", file=sys.stderr) + return 1 + + config = CollectorConfig( + collector_base_url=opts.collector, + timeout=opts.timeout, + ) + client = CollectorClient(config) + + for idx, measurement in enumerate(measurements, 1): + print(f"Submitting measurement {idx}/{len(measurements)}...") + + try: + # Create a new report for this measurement + report_id = client.create_report_from_measurement(measurement) + print(f"Created report {report_id}") + + # Submit the measurement to the report + measurement_uid = client.update_report(report_id, measurement) + if measurement_uid: + print(f"Submitted measurement: {measurement_uid}") + else: + print("Submitted measurement (no UID returned)") + + except Exception as exc: + print(f"Failed to submit measurement: {exc}", file=sys.stderr) + return 1 + + print("All measurements submitted successfully") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/oonireport/collector.py b/oonireport/collector.py new file mode 100644 index 0000000..c6b80b0 --- /dev/null +++ b/oonireport/collector.py @@ -0,0 +1,154 @@ +""" +Implements the OONI collector protocol. + +See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. +""" + +from dataclasses import dataclass +from datetime import datetime +from typing import Optional +from urllib.parse import urljoin + +import json +import urllib.request + +from .model import APIError, Measurement, datetime_to_ooni_format + + +@dataclass +class CollectorConfig: + """Contains configuration for the OONI collector client.""" + + collector_base_url: str # e.g., "https://api.ooni.io/" + timeout: float = 30.0 + + +class CollectorClient: + """ + Implements the OONI collector client protocol. + + See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. + """ + + # TODO(bassosimone): consider support for retries. The spec + # says the following about retries: + # + # > A client side implementation MAY retry any failing collector + # > operation immediately for three times in case there is a + # > DNS or TCP error. [...] If all these immediate retries fail, + # > then the client SHOULD arrange for resubmitting the + # > measurement at a later time + # + # However, for now, we're good without implementing retries. + + def __init__(self, config: CollectorConfig): + self.config = config + + def create_report_from_measurement(self, measurement: Measurement) -> str: + """Convenience method to create report from existing OONI Measurement.""" + return self.create_report( + test_name=measurement.test_name, + test_version=measurement.test_version, + software_name=measurement.software_name, + software_version=measurement.software_version, + probe_asn=measurement.probe_asn, + probe_cc=measurement.probe_cc, + test_start_time=measurement.test_start_time, + ) + + def create_report( + self, + test_name: str, + test_version: str, + software_name: str, + software_version: str, + probe_asn: str, + probe_cc: str, + test_start_time: datetime, + ) -> str: + """ + Creates a new report and returns the report ID. + + Returns: + Report ID to use for submitting measurements. + + Raises: + model.APIError: in case of failure. + """ + report = { + "data_format_version": "0.2.0", + "format": "json", + "probe_asn": probe_asn, + "probe_cc": probe_cc, + "software_name": software_name, + "software_version": software_version, + "test_name": test_name, + "test_start_time": datetime_to_ooni_format(test_start_time), + "test_version": test_version, + } + + data = json.dumps(report).encode("utf-8") + req = urllib.request.Request( + urljoin(self.config.collector_base_url, "report"), + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: + if resp.status != 200: + raise APIError(f"unexpected status: {resp.status}") + response = json.loads(resp.read().decode()) + if "report_id" not in response: + raise APIError("missing report_id in response") + return response["report_id"] + + except Exception as exc: + raise APIError(f"HTTP error: {exc}") + + def update_report( + self, + report_id: str, + measurement: Measurement, + ) -> Optional[str]: + """ + Update a report by adding a measurement. + + Args: + report_id: The ID returned by create_report(). + measurement: The measurement to submit. + + Returns: + The measurement_uid, if provided by server, otherwise None. + + Raises: + model.APIError: in case of failure. + """ + measurement.report_id = report_id # Required for Explorer visualization + data = json.dumps( + { + "format": "json", + "content": measurement.as_dict(), + } + ).encode("utf-8") + + req = urllib.request.Request( + urljoin(self.config.collector_base_url, f"report/{report_id}"), + data=data, + headers={"Content-Type": "application/json"}, + method="POST", + ) + + try: + with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: + if resp.status != 200: + raise APIError(f"unexpected status: {resp.status}") + try: + response = json.loads(resp.read().decode()) + return response.get("measurement_uid") + except json.JSONDecodeError: + return None + + except Exception as exc: + raise APIError(f"HTTP error: {exc}") diff --git a/oonireport/load.py b/oonireport/load.py new file mode 100644 index 0000000..5104ff4 --- /dev/null +++ b/oonireport/load.py @@ -0,0 +1,123 @@ +"""Functions for loading OONI measurements from a given file.""" + +from datetime import datetime, timezone +from typing import Dict, List + +import json + +from .model import Measurement + + +class _DictTestKeys: + """Wrapper for raw test_keys dict implementing TestKeys protocol.""" + + def __init__(self, keys: Dict): + self._keys = keys + + def as_dict(self) -> Dict: + return self._keys + + +def _load_measurement(raw_data: str) -> Measurement: + """ + Creates a Measurement from JSON data. + + Args: + data: a JSON string. + + Returns: + A Measurement instance. + + Raises: + ValueError: if the data is invalid. + """ + data = json.loads(raw_data) + + # Basic measurement validation + required_fields = { + "annotations", + "data_format_version", + "input", + "measurement_start_time", + "probe_asn", + "probe_cc", + "software_name", + "software_version", + "test_keys", + "test_name", + "test_runtime", + "test_start_time", + "test_version", + } + missing = required_fields - set(data.keys()) + if missing: + raise ValueError(f"Missing required fields: {missing}") + + # Parse dates + try: + measurement_start_time = datetime.strptime( + data["measurement_start_time"], "%Y-%m-%d %H:%M:%S" + ).replace(tzinfo=timezone.utc) + + test_start_time = datetime.strptime( + data["test_start_time"], "%Y-%m-%d %H:%M:%S" + ).replace(tzinfo=timezone.utc) + except ValueError as exc: + raise ValueError(f"Invalid datetime format: {exc}") + + # Assemble the measurement + return Measurement( + # Mandatory fields + annotations=data["annotations"], + data_format_version=data["data_format_version"], + input=data["input"], + measurement_start_time=measurement_start_time, + probe_asn=data["probe_asn"], + probe_cc=data["probe_cc"], + software_name=data["software_name"], + software_version=data["software_version"], + test_keys=_DictTestKeys(data["test_keys"]), + test_name=data["test_name"], + test_runtime=float(data["test_runtime"]), + test_start_time=test_start_time, + test_version=data["test_version"], + # Fields emitted with possibly default values + probe_ip=data.get("probe_ip", "127.0.0.1"), + report_id=data.get("report_id", ""), + # Optional fields + options=data.get("options", []), + probe_network_name=data.get("probe_network_name", ""), + resolver_asn=data.get("resolver_asn", ""), + resolver_ip=data.get("resolver_ip", ""), + resolver_network_name=data.get("resolver_network_name", ""), + test_helpers=data.get("test_helpers", {}), + ) + + +def load_measurements(path: str) -> List[Measurement]: + """ + Loads measurements from a JSON file containing one measurement per line. + + Args: + path: Path to the JSON file + + Returns: + List of Measurement instances. + + Raises: + ValueError: if the file format is invalid. + OSError: if file operations fail. + """ + with open(path) as filep: + content = filep.read().strip() + if not content: + return [] + + # Try parsing as newline-delimited JSON + measurements = [] + for line in content.split("\n"): + line = line.strip() + if line: # Skip empty lines + measurements.append(_load_measurement(line)) + + return measurements diff --git a/oonireport/model.py b/oonireport/model.py new file mode 100644 index 0000000..1439358 --- /dev/null +++ b/oonireport/model.py @@ -0,0 +1,101 @@ +""" +OONI measurement model. + +See https://github.com/ooni/spec/blob/master/data-formats/df-000-base.md. +""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Dict, Protocol + + +class APIError(Exception): + """Raised when there are OONI API errors.""" + + +class TestKeys(Protocol): + """Models the OONI measurement test keys.""" + + def as_dict(self) -> Dict: + """Converts the test keys to a JSON-serializable dict.""" + ... + + +@dataclass +class Measurement: + """Models the OONI measurement envelope.""" + + # mandatory fields + annotations: Dict[str, str] + data_format_version: str + input: str # e.g., {protocol}://{provider}/?{query_string} + measurement_start_time: datetime + probe_asn: str # Format: ^AS[0-9]+$ + probe_cc: str # Format: ^[A-Z]{2}$ + software_name: str + software_version: str + test_keys: TestKeys + test_name: str + test_runtime: float + test_start_time: datetime + test_version: str + + # Fields emitted with possibly default values + probe_ip: str = "127.0.0.1" + report_id: str = "" + + # Optional fields + options: list[str] = field(default_factory=list) + probe_network_name: str = "" + resolver_asn: str = "" + resolver_cc: str = "" + resolver_ip: str = "" + resolver_network_name: str = "" + test_helpers: Dict[str, Any] = field(default_factory=dict) + + def as_dict(self) -> Dict: + """Converts the measurement to a JSON-serializable dict""" + + # Add mandatory fields + dct = { + "annotations": self.annotations, + "data_format_version": self.data_format_version, + "input": self.input, + "measurement_start_time": datetime_to_ooni_format( + self.measurement_start_time + ), + "probe_asn": self.probe_asn, + "probe_cc": self.probe_cc, + "software_name": self.software_name, + "software_version": self.software_version, + "test_keys": self.test_keys.as_dict(), + "test_name": self.test_name, + "test_runtime": self.test_runtime, + "test_start_time": datetime_to_ooni_format(self.test_start_time), + "test_version": self.test_version, + } + + # Fields emitted with possibly default values + dct["probe_ip"] = self.probe_ip if self.probe_ip else "127.0.0.1" + dct["report_id"] = self.report_id + + # Add optional fields + if self.options: + dct["options"] = self.options + if self.probe_network_name: + dct["probe_network_name"] = self.probe_network_name + if self.resolver_asn: + dct["resolver_asn"] = self.resolver_asn + if self.resolver_ip: + dct["resolver_ip"] = self.resolver_ip + if self.resolver_network_name: + dct["resolver_network_name"] = self.resolver_network_name + if self.test_helpers: + dct["test_helpers"] = self.test_helpers + + return dct + + +def datetime_to_ooni_format(dt: datetime) -> str: + """Converts a datetime to OONI's datetime format (YYYY-mm-dd HH:MM:SS).""" + return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") diff --git a/tests/oonireport/__init__.py b/tests/oonireport/__init__.py new file mode 100644 index 0000000..cbbb2ce --- /dev/null +++ b/tests/oonireport/__init__.py @@ -0,0 +1 @@ +"""Tests for the oonireport module.""" diff --git a/tests/oonireport/test_collector.py b/tests/oonireport/test_collector.py new file mode 100644 index 0000000..c3b6060 --- /dev/null +++ b/tests/oonireport/test_collector.py @@ -0,0 +1,142 @@ +"""Tests for the collector module.""" + +from unittest.mock import patch +from datetime import datetime, timezone + +import json +import unittest + +from oonireport.collector import CollectorClient, CollectorConfig +from oonireport.model import APIError, Measurement + + +class MockResponse: + """Simulates urllib response object.""" + + def __init__(self, status, data): + self.status = status + self._data = data + + def read(self): + return json.dumps(self._data).encode() + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + +class SimpleTestKeys: + """Simple TestKeys implementation for testing.""" + + def as_dict(self): + return {"simple": "test"} + + +class TestCollectorClient(unittest.TestCase): + def setUp(self): + self.config = CollectorConfig( + collector_base_url="https://example.org", timeout=30.0 + ) + self.client = CollectorClient(self.config) + + @patch("urllib.request.urlopen") + def test_create_report_success(self, mock_urlopen): + mock_urlopen.return_value = MockResponse(200, {"report_id": "test_report_id"}) + + report_id = self.client.create_report( + test_name="web_connectivity", + test_version="0.0.1", + software_name="ooniprobe", + software_version="3.0.0", + probe_asn="AS12345", + probe_cc="IT", + ) + + self.assertEqual(report_id, "test_report_id") + + @patch("urllib.request.urlopen") + def test_create_report_http_error(self, mock_urlopen): + mock_urlopen.side_effect = Exception("Connection failed") + + with self.assertRaises(APIError) as cm: + self.client.create_report( + test_name="web_connectivity", + test_version="0.0.1", + software_name="ooniprobe", + software_version="3.0.0", + probe_asn="AS12345", + probe_cc="IT", + ) + + self.assertIn("HTTP error", str(cm.exception)) + + @patch("urllib.request.urlopen") + def test_create_report_invalid_response(self, mock_urlopen): + mock_urlopen.return_value = MockResponse( + 200, {"wrong_field": "value"} # Missing report_id + ) + + with self.assertRaises(APIError) as cm: + self.client.create_report( + test_name="web_connectivity", + test_version="0.0.1", + software_name="ooniprobe", + software_version="3.0.0", + probe_asn="AS12345", + probe_cc="IT", + ) + + self.assertIn("missing report_id", str(cm.exception)) + + @patch("urllib.request.urlopen") + def test_update_report_success(self, mock_urlopen): + mock_urlopen.return_value = MockResponse( + 200, {"measurement_id": "test_measurement_id"} + ) + + measurement = Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=datetime.now(timezone.utc), + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=datetime.now(timezone.utc), + test_version="0.0.1", + ) + + measurement_id = self.client.update_report("test_report_id", measurement) + + self.assertEqual(measurement_id, "test_measurement_id") + + @patch("urllib.request.urlopen") + def test_update_report_http_error(self, mock_urlopen): + mock_urlopen.side_effect = Exception("Connection failed") + + measurement = Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=datetime.now(timezone.utc), + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=datetime.now(timezone.utc), + test_version="0.0.1", + ) + + with self.assertRaises(APIError) as cm: + self.client.update_report("test_report_id", measurement) + + self.assertIn("HTTP error", str(cm.exception)) diff --git a/tests/oonireport/test_load.py b/tests/oonireport/test_load.py new file mode 100644 index 0000000..d166d93 --- /dev/null +++ b/tests/oonireport/test_load.py @@ -0,0 +1,39 @@ +"""Tests for measurement loading functionality.""" + +import json +import os +import tempfile +import unittest + +from oonireport import load_measurements + + +SAMPLE_MEASUREMENT = { + "annotations": {}, + "data_format_version": "0.2.0", + "input": "https://example.com/", + "measurement_start_time": "2023-01-01 12:00:00", + "probe_asn": "AS12345", + "probe_cc": "IT", + "software_name": "ooniprobe", + "software_version": "3.0.0", + "test_keys": {"simple": "test"}, + "test_name": "web_connectivity", + "test_runtime": 1.0, + "test_start_time": "2023-01-01 12:00:00", + "test_version": "0.0.1", +} + + +class TestMeasurementLoading(unittest.TestCase): + + def test_load_from_file(self): + with tempfile.NamedTemporaryFile(mode="w", delete=False) as filep: + filep.write(json.dumps(SAMPLE_MEASUREMENT) + "\n") + filep.write(json.dumps(SAMPLE_MEASUREMENT) + "\n") + + try: + measurements = load_measurements(filep.name) + self.assertEqual(len(measurements), 2) + finally: + os.unlink(filep.name) diff --git a/tests/oonireport/test_model.py b/tests/oonireport/test_model.py new file mode 100644 index 0000000..a2ede64 --- /dev/null +++ b/tests/oonireport/test_model.py @@ -0,0 +1,63 @@ +"""Tests for the model module.""" + +from datetime import datetime, timedelta, timezone + +import unittest + +from oonireport.model import Measurement, datetime_to_ooni_format + + +class SimpleTestKeys: + """Simple TestKeys implementation for testing.""" + + def as_dict(self): + return {"simple": "test"} + + +class TestModel(unittest.TestCase): + + def test_measurement_as_dict(self): + dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) + measurement = Measurement( + annotations={"annotation_key": "value"}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=dt, + test_version="0.0.1", + ) + + data = measurement.as_dict() + + self.assertEqual(data["annotations"], {"annotation_key": "value"}) + self.assertEqual(data["data_format_version"], "0.2.0") + self.assertEqual(data["input"], "https://example.com") + self.assertEqual(data["measurement_start_time"], "2023-01-01 12:00:00") + self.assertEqual(data["probe_asn"], "AS12345") + self.assertEqual(data["probe_cc"], "IT") + self.assertEqual(data["test_keys"], {"simple": "test"}) + self.assertEqual(data["test_name"], "web_connectivity") + self.assertEqual(data["test_runtime"], 1.0) + self.assertEqual(data["test_start_time"], "2023-01-01 12:00:00") + self.assertEqual(data["test_version"], "0.0.1") + + def test_datetime_to_ooni_format(self): + dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) + formatted = datetime_to_ooni_format(dt) + self.assertEqual(formatted, "2023-01-01 12:00:00") + + # Test timezone conversion + dt = datetime(2023, 1, 1, 14, 0, tzinfo=timezone(timedelta(hours=2))) + formatted = datetime_to_ooni_format(dt) + self.assertEqual(formatted, "2023-01-01 12:00:00") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/oonisubmitter/test_identifiers.py b/tests/oonisubmitter/test_identifiers.py index 135131f..6310fe4 100644 --- a/tests/oonisubmitter/test_identifiers.py +++ b/tests/oonisubmitter/test_identifiers.py @@ -79,7 +79,9 @@ class TestBridgeEndpointID(unittest.TestCase): def test_parse_valid_string(self): """Test parsing a valid string into BridgeEndpointID.""" - identifier = f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}443{SEPARATOR}obfs4" + identifier = ( + f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}443{SEPARATOR}obfs4" + ) endpoint = BridgeEndpointID.parse(identifier) self.assertEqual(endpoint.hostname, "bridge.example.com") self.assertEqual(endpoint.address, "1.2.3.4") @@ -105,7 +107,9 @@ class TestBridgeEndpointID(unittest.TestCase): def test_parse_invalid_port(self): """Test parsing string with non-integer port raises ValueError.""" - identifier = f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}invalid{SEPARATOR}obfs4" + identifier = ( + f"bridge.example.com{SEPARATOR}1.2.3.4{SEPARATOR}invalid{SEPARATOR}obfs4" + ) with self.assertRaises(ValueError): BridgeEndpointID.parse(identifier) -- GitLab From 6922821e95cfd5a641f9bc655e53e9a0f58f8e13 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 14:44:10 +0100 Subject: [PATCH 24/75] refactor(oonireport): changes after self code review --- oonireport/__init__.py | 26 ++++++++++++-- oonireport/__main__.py | 75 +++++++++++++++++++++++++++++------------ oonireport/collector.py | 2 ++ oonireport/load.py | 28 ++++++++++++++- oonireport/model.py | 2 ++ 5 files changed, 108 insertions(+), 25 deletions(-) diff --git a/oonireport/__init__.py b/oonireport/__init__.py index 4031900..ea29ba0 100644 --- a/oonireport/__init__.py +++ b/oonireport/__init__.py @@ -1,10 +1,30 @@ """ -Code to submit OONI reports to the OONI collector. +OONI Report Submission Library +============================== -Exposes a subset of the original `oonireport(1)` CLI tool -usable via `python3 -m oonireport [args...]`. +This package provides classes for interacting with OONI collectors +servers and submitting OONI measurements to them. + +Classes: + CollectorClient: A client for interacting with OONI collectors. + CollectorConfig: Configuration for the `CollectorClient`. + Measurement: A class representing an OONI measurement external envelope. + TestKeys: A class representing the keys of an OONI measurement. + +Functions: + load_measurements: Load measurements from given file containing one + measurement per line, each measurement being a JSON object. + +Exceptions: + APIError: An exception raised in case of failure. + +Scripts: + This package can also be run as a script using `python3 -m oonireport`. Run + the `python -m oonireport --help` command for detailed usage. """ +# SPDX-License-Identifier: GPL-3.0-or-later + from .collector import CollectorClient, CollectorConfig from .load import load_measurements from .model import APIError, Measurement, TestKeys diff --git a/oonireport/__main__.py b/oonireport/__main__.py index 6162e00..8738509 100644 --- a/oonireport/__main__.py +++ b/oonireport/__main__.py @@ -3,27 +3,33 @@ Command line interface for submitting OONI measurements emulating a subset of the `oonireport(1)` tool. """ +# SPDX-License-Identifier: GPL-3.0-or-later + from typing import List, Optional import argparse +import json +import os import sys from . import CollectorClient, CollectorConfig, load_measurements def main(args: Optional[List[str]] = None) -> int: - """Main function implementing the `oonireport(1)` functionality.""" + """Main function implementing the `oonireport(1)` command line tool.""" parser = argparse.ArgumentParser( - description="Submit OONI measurements to a collector" + description="Submit measurements to the OONI collector" ) subparsers = parser.add_subparsers(dest="command", required=True) - upload = subparsers.add_parser("upload", help="upload measurements to a collector") + upload = subparsers.add_parser( + "upload", help="upload measurements to the OONI collector" + ) upload.add_argument( - "-f", - "--file", - required=True, - help="measurement file to submit", + "-F", + "--dump-failed", + action="store_true", + help="dump to stdout the measurements we could not submit", ) upload.add_argument( "-c", @@ -31,6 +37,18 @@ def main(args: Optional[List[str]] = None) -> int: default="https://api.ooni.io/", help="collector base URL (default: https://api.ooni.io/)", ) + upload.add_argument( + "-d", + "--delete-input-file", + action="store_true", + help="delete the input file when done submitting", + ) + upload.add_argument( + "-f", + "--file", + required=True, + help="measurement file to submit", + ) upload.add_argument( "-t", "--timeout", @@ -41,17 +59,17 @@ def main(args: Optional[List[str]] = None) -> int: opts = parser.parse_args(args) if opts.command != "upload": - print("Unknown command", file=sys.stderr) + print("oonireport: unknown command", file=sys.stderr) return 1 try: measurements = load_measurements(opts.file) except Exception as exc: - print(f"Failed to load measurements: {exc}", file=sys.stderr) + print(f"oonireport: failed to load measurements: {exc}", file=sys.stderr) return 1 if not measurements: - print("No measurements to submit", file=sys.stderr) + print("oonireport: no measurements to submit", file=sys.stderr) return 1 config = CollectorConfig( @@ -60,27 +78,42 @@ def main(args: Optional[List[str]] = None) -> int: ) client = CollectorClient(config) + numfailures = 0 for idx, measurement in enumerate(measurements, 1): - print(f"Submitting measurement {idx}/{len(measurements)}...") + print( + f"oonireport: submitting measurement {idx}/{len(measurements)}...", + file=sys.stderr, + ) try: # Create a new report for this measurement report_id = client.create_report_from_measurement(measurement) - print(f"Created report {report_id}") + print(f"oonireport: created report {report_id}", file=sys.stderr) - # Submit the measurement to the report + # Append measurement to the report measurement_uid = client.update_report(report_id, measurement) - if measurement_uid: - print(f"Submitted measurement: {measurement_uid}") - else: - print("Submitted measurement (no UID returned)") + if not measurement_uid: + measurement_uid = "N/A" + print( + f"oonireport: submitted measurement: {measurement_uid}", file=sys.stderr + ) except Exception as exc: - print(f"Failed to submit measurement: {exc}", file=sys.stderr) - return 1 + print(f"oonireport: failed to submit measurement: {exc}", file=sys.stderr) + numfailures += 1 + if opts.dump_failed: + print(json.dumps(measurement.as_dict()), file=sys.stdout) + + print("oonireport: finished submitting measurements", file=sys.stderr) + + if opts.delete_input_file: + print("oonireport: deleting input file", file=sys.stderr) + try: + os.unlink(opts.file) + except Exception as exc: + print(f"oonireport: failed to delete input file: {exc}", file=sys.stderr) - print("All measurements submitted successfully") - return 0 + return 0 if numfailures == 0 else 1 if __name__ == "__main__": diff --git a/oonireport/collector.py b/oonireport/collector.py index c6b80b0..e035925 100644 --- a/oonireport/collector.py +++ b/oonireport/collector.py @@ -4,6 +4,8 @@ Implements the OONI collector protocol. See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. """ +# SPDX-License-Identifier: GPL-3.0-or-later + from dataclasses import dataclass from datetime import datetime from typing import Optional diff --git a/oonireport/load.py b/oonireport/load.py index 5104ff4..8ae9eab 100644 --- a/oonireport/load.py +++ b/oonireport/load.py @@ -1,7 +1,9 @@ """Functions for loading OONI measurements from a given file.""" +# SPDX-License-Identifier: GPL-3.0-or-later + from datetime import datetime, timezone -from typing import Dict, List +from typing import Dict, Iterator, List import json @@ -121,3 +123,27 @@ def load_measurements(path: str) -> List[Measurement]: measurements.append(_load_measurement(line)) return measurements + + +def stream_measurements(path: str) -> Iterator[Measurement]: + """ + Streams measurements from a JSON file containing one measurement per line. + + Is more efficient than load_measurements() for large files. + + Args: + path: Path to the JSON file + + Returns: + Iterator yielding Measurement instances. + + Raises: + ValueError: if the file format is invalid. + OSError: if file operations fail. + + Note: + Not implemented yet. + """ + # TODO(bassosimone): implement this function and reimplement + # load_measurements() in terms of it. + raise NotImplementedError() diff --git a/oonireport/model.py b/oonireport/model.py index 1439358..8393b4e 100644 --- a/oonireport/model.py +++ b/oonireport/model.py @@ -4,6 +4,8 @@ OONI measurement model. See https://github.com/ooni/spec/blob/master/data-formats/df-000-base.md. """ +# SPDX-License-Identifier: GPL-3.0-or-later + from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Dict, Protocol -- GitLab From 167014dcf9b165045f1d2f00c5a81bba42ffbdf5 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 15:26:51 +0100 Subject: [PATCH 25/75] refactor: move lockedfile into its own package --- lockedfile/__init__.py | 22 ++++++ lockedfile/common.py | 7 ++ lockedfile/fileio.py | 105 ++++++++++++++++++++++++ lockedfile/mutex.py | 41 ++++++++++ oonisubmitter/lockedfile.py | 154 ++++-------------------------------- 5 files changed, 190 insertions(+), 139 deletions(-) create mode 100644 lockedfile/__init__.py create mode 100644 lockedfile/common.py create mode 100644 lockedfile/fileio.py create mode 100644 lockedfile/mutex.py diff --git a/lockedfile/__init__.py b/lockedfile/__init__.py new file mode 100644 index 0000000..ea6366f --- /dev/null +++ b/lockedfile/__init__.py @@ -0,0 +1,22 @@ +""" +File locking utilities for safe concurrent access. + +This package provides utilities for reading and writing files while +holding locks on them, ensuring safe concurrent access. + +Note: this package only works on Unix systems. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from .common import FileLockError +from .fileio import FileIOConfig, read, write +from .mutex import Mutex + +__all__ = [ + "FileIOConfig", + "FileLockError", + "Mutex", + "read", + "write", +] diff --git a/lockedfile/common.py b/lockedfile/common.py new file mode 100644 index 0000000..f74c32f --- /dev/null +++ b/lockedfile/common.py @@ -0,0 +1,7 @@ +"""Common code for file locking.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +class FileLockError(Exception): + """Error emitted by this module.""" diff --git a/lockedfile/fileio.py b/lockedfile/fileio.py new file mode 100644 index 0000000..3c0b38b --- /dev/null +++ b/lockedfile/fileio.py @@ -0,0 +1,105 @@ +"""Safe file I/O operations with locking.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from io import TextIOWrapper +from typing import Optional + +import fcntl +import os +import time + +from .common import FileLockError + + +@dataclass +class FileIOConfig: + """ + Configures attempting to acquire a file lock. + + Fields: + num_retries: Number of times to retry acquiring the lock + sleep_interval: Time to sleep between each attempt (in seconds) + """ + + num_retries: int = 10 + sleep_interval: float = 0.1 + + +def read(filepath: str, config: Optional[FileIOConfig] = None) -> str: + """ + Read entire file while holding a shared lock. + + Args: + filepath: Path to file to read + config: Optional locking configuration + + Raises: + FileLockError: if cannot acquire the file lock + IOError: if file operations fail + FileNotFoundError: if the file does not exist + """ + with open(filepath, "r") as filep: + if not _acquire_shared(filep, config): + raise FileLockError(f"cannot acquire read lock on {filepath}") + try: + return filep.read() + finally: + _release(filep) + + +def write(filepath: str, data: str, config: Optional[FileIOConfig] = None) -> None: + """ + Write entire file while holding an exclusive lock. + + Args: + filepath: Path to file to write + data: Content to write + config: Optional locking configuration + + Raises: + FileLockError: if cannot acquire the file lock + IOError: if file operations fail + FileNotFoundError: if the file does not exist + """ + with open(filepath, "w") as filep: + if not _acquire_exclusive(filep, config): + raise FileLockError(f"cannot acquire write lock on {filepath}") + try: + filep.write(data) + # Implementation note: flush to buffer cache and then + # persist to permanent storage with fsync + filep.flush() + os.fsync(filep.fileno()) + finally: + _release(filep) + + +def _acquire_shared(filep: TextIOWrapper, config: Optional[FileIOConfig]) -> bool: + return _try_lock(filep, fcntl.LOCK_SH | fcntl.LOCK_NB, config) + + +def _acquire_exclusive(filep: TextIOWrapper, config: Optional[FileIOConfig]) -> bool: + return _try_lock(filep, fcntl.LOCK_EX | fcntl.LOCK_NB, config) + + +def _release(filep: TextIOWrapper) -> None: + try: + fcntl.flock(filep.fileno(), fcntl.LOCK_UN) + except OSError: + pass + + +def _try_lock( + filep: TextIOWrapper, operation: int, config: Optional[FileIOConfig] +) -> bool: + if not config: + config = FileIOConfig() + for _ in range(config.num_retries): + try: + fcntl.flock(filep.fileno(), operation) + return True + except BlockingIOError: + time.sleep(config.sleep_interval) + return False diff --git a/lockedfile/mutex.py b/lockedfile/mutex.py new file mode 100644 index 0000000..3f12f08 --- /dev/null +++ b/lockedfile/mutex.py @@ -0,0 +1,41 @@ +"""Mutex class for mutual exclusion using flock""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +import fcntl + +from .common import FileLockError + + +class Mutex: + """ + Provides mutual exclusion using a lock file and flock. + + The lock file persists between uses - only the lock itself + is acquired and released, not the file's existence. + """ + + def __init__(self, filepath: str): + self.filepath = filepath + self.filep = None + + def __enter__(self) -> "Mutex": + """Acquire exclusive lock using flock""" + try: + # Open for read-write, create if doesn't exist + self.filep = open(self.filepath, "a+") + fcntl.flock(self.filep.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) + return self + except (IOError, BlockingIOError) as err: + if self.filep: + self.filep.close() + raise FileLockError(f"cannot acquire lock: {err}") + + def __exit__(self, *args): + """Release lock and close file""" + if self.filep: + try: + fcntl.flock(self.filep.fileno(), fcntl.LOCK_UN) + finally: + self.filep.close() + self.filep = None diff --git a/oonisubmitter/lockedfile.py b/oonisubmitter/lockedfile.py index 66b06c9..d776a10 100644 --- a/oonisubmitter/lockedfile.py +++ b/oonisubmitter/lockedfile.py @@ -8,142 +8,18 @@ Patterned after https://github.com/rogpeppe/go-internal `lockedfile`. SPDX-License-Identifier: BSD-3-Clause """ -from dataclasses import dataclass -from io import TextIOWrapper -from typing import Optional - -import fcntl -import os -import time - -# TODO(bassosimone): write tests for this functionality -# once we have addresses more pressing issues with wiring -# in the `oonisubmitter` module into the ETL pipeline. - - -class FileLockError(Exception): - """Error emitted by this module.""" - - -@dataclass -class Config: - """ - Configures attempting to acquire a lock. - - Fields: - num_retries: Number of times to retry acquiring the lock - sleep_interval: Time to sleep between each attempt (in seconds) - """ - - num_retries: int = 10 - sleep_interval: float = 0.1 - - -class Mutex: - """ - Provides mutual exclusion using a lock file and flock. - - The lock file persists between uses - only the lock itself - is acquired and released, not the file's existence. - """ - - def __init__(self, filepath: str): - self.filepath = filepath - self.filep = None - - def __enter__(self) -> "Mutex": - """Acquire exclusive lock using flock""" - try: - # Open for read-write, create if doesn't exist - self.filep = open(self.filepath, "a+") - fcntl.flock(self.filep.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB) - return self - except (IOError, BlockingIOError) as err: - if self.filep: - self.filep.close() - raise FileLockError(f"cannot acquire lock: {err}") - - def __exit__(self, *args): - """Release lock and close file""" - if self.filep: - try: - fcntl.flock(self.filep.fileno(), fcntl.LOCK_UN) - finally: - self.filep.close() - self.filep = None - - -def read(filepath: str, config: Optional[Config] = None) -> str: - """ - Read entire file while holding a shared lock. - - Args: - filepath: Path to file to read - config: Optional locking configuration - - Raises: - FileLockError: if cannot acquire the file lock - IOError: if file operations fail - FileNotFoundError: if the file does not exist - """ - with open(filepath, "r") as filep: - if not _acquire_shared(filep, config): - raise FileLockError(f"cannot acquire read lock on {filepath}") - try: - return filep.read() - finally: - _release(filep) - - -def write(filepath: str, data: str, config: Optional[Config] = None) -> None: - """ - Write entire file while holding an exclusive lock. - - Args: - filepath: Path to file to write - data: Content to write - config: Optional locking configuration - - Raises: - FileLockError: if cannot acquire the file lock - IOError: if file operations fail - FileNotFoundError: if the file does not exist - """ - with open(filepath, "w") as filep: - if not _acquire_exclusive(filep, config): - raise FileLockError(f"cannot acquire write lock on {filepath}") - try: - filep.write(data) - # Implementation note: flush to buffer cache and then - # persist to permanent storage with fsync - filep.flush() - os.fsync(filep.fileno()) - finally: - _release(filep) - - -def _acquire_shared(filep: TextIOWrapper, config: Optional[Config]) -> bool: - return _try_lock(filep, fcntl.LOCK_SH | fcntl.LOCK_NB, config) - - -def _acquire_exclusive(filep: TextIOWrapper, config: Optional[Config]) -> bool: - return _try_lock(filep, fcntl.LOCK_EX | fcntl.LOCK_NB, config) - - -def _release(filep: TextIOWrapper) -> None: - try: - fcntl.flock(filep.fileno(), fcntl.LOCK_UN) - except OSError: - pass - - -def _try_lock(filep: TextIOWrapper, operation: int, config: Optional[Config]) -> bool: - if not config: - config = Config() - for _ in range(config.num_retries): - try: - fcntl.flock(filep.fileno(), operation) - return True - except BlockingIOError: - time.sleep(config.sleep_interval) - return False +from lockedfile import ( + FileIOConfig as Config, + FileLockError, + Mutex, + read, + write, +) + +__all__ = [ + "Config", + "FileLockError", + "Mutex", + "read", + "write", +] -- GitLab From 0ec0633b539dd37b88517aae312530157ae7ddd6 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 15:54:56 +0100 Subject: [PATCH 26/75] refactor(oonisubmitter): mark lockedfile as backward compat This diff explicitly mentions that lockedfile.py in oonisubmitter is ther only for backward compatibility. I will need to refactor and consolitate many more modules before moving forward with further attempting to split the codebase. --- lockedfile/__init__.py | 2 ++ oonireport/__init__.py | 2 ++ oonisubmitter/lockedfile.py | 9 ++------- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/lockedfile/__init__.py b/lockedfile/__init__.py index ea6366f..9aaeadd 100644 --- a/lockedfile/__init__.py +++ b/lockedfile/__init__.py @@ -4,6 +4,8 @@ File locking utilities for safe concurrent access. This package provides utilities for reading and writing files while holding locks on them, ensuring safe concurrent access. +Patterned after https://github.com/rogpeppe/go-internal `lockedfile`. + Note: this package only works on Unix systems. """ diff --git a/oonireport/__init__.py b/oonireport/__init__.py index ea29ba0..485dd5c 100644 --- a/oonireport/__init__.py +++ b/oonireport/__init__.py @@ -5,6 +5,8 @@ OONI Report Submission Library This package provides classes for interacting with OONI collectors servers and submitting OONI measurements to them. +See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. + Classes: CollectorClient: A client for interacting with OONI collectors. CollectorConfig: Configuration for the `CollectorClient`. diff --git a/oonisubmitter/lockedfile.py b/oonisubmitter/lockedfile.py index d776a10..1977ecf 100644 --- a/oonisubmitter/lockedfile.py +++ b/oonisubmitter/lockedfile.py @@ -1,12 +1,7 @@ -""" -Support module for reading and writing files while holding a lock on them. +"""Backward compatibility lockefile module.""" -Only work on Unix systems. +# TODO(bassosimone): refactor to use lockefile directly. -Patterned after https://github.com/rogpeppe/go-internal `lockedfile`. - -SPDX-License-Identifier: BSD-3-Clause -""" from lockedfile import ( FileIOConfig as Config, -- GitLab From 7aba46b85728af5c588e48abf74406e8b5736b50 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 16:03:16 +0100 Subject: [PATCH 27/75] refactor: reuse datetime_to_ooni_format --- oonireport/__init__.py | 6 +++++- oonisubmitter/model.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/oonireport/__init__.py b/oonireport/__init__.py index 485dd5c..e7b1a15 100644 --- a/oonireport/__init__.py +++ b/oonireport/__init__.py @@ -14,6 +14,9 @@ Classes: TestKeys: A class representing the keys of an OONI measurement. Functions: + datetime_to_ooni_format: Convert a datetime object to an + serialized string using the OONI collector format. + load_measurements: Load measurements from given file containing one measurement per line, each measurement being a JSON object. @@ -29,7 +32,7 @@ Scripts: from .collector import CollectorClient, CollectorConfig from .load import load_measurements -from .model import APIError, Measurement, TestKeys +from .model import APIError, Measurement, TestKeys, datetime_to_ooni_format __all__ = [ "APIError", @@ -37,5 +40,6 @@ __all__ = [ "CollectorConfig", "Measurement", "TestKeys", + "datetime_to_ooni_format", "load_measurements", ] diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index b4f818f..168b1c4 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -34,6 +34,8 @@ from typing import Dict, List, Optional import csv import logging +import oonireport + class Scope(Enum): """Valid scopes for aggregate tunnel metrics.""" @@ -219,9 +221,7 @@ def datetime_to_compact_utc(dt: datetime) -> str: return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") -def datetime_to_ooni_format(dt: datetime) -> str: - """Convert datetime to OONI's format (YYYY-MM-DD hh:mm:ss)""" - return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") +datetime_to_ooni_format = oonireport.datetime_to_ooni_format @dataclass -- GitLab From d3c85f8db848441db49fb171de50d1f7db3011e9 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 16:14:20 +0100 Subject: [PATCH 28/75] refactor: consolidate on the Measurement definition --- oonisubmitter/model.py | 40 +--------------------------------------- 1 file changed, 1 insertion(+), 39 deletions(-) diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 168b1c4..7354cf2 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -285,42 +285,4 @@ class AggregateTunnelMetricsTestKeys: return d -@dataclass -class OONIMeasurement: - """ - Models the OONI measurement envelope. - """ - - annotations: Dict[str, str] - data_format_version: str - input: str # {protocol}://{provider}/?{query_string} - measurement_start_time: datetime - probe_asn: str # Format: ^AS[0-9]+$ - probe_cc: str # Format: ^[A-Z]{2}$ - software_name: str - software_version: str - test_keys: AggregateTunnelMetricsTestKeys - test_name: str - test_runtime: float - test_start_time: datetime - test_version: str - - def as_dict(self) -> Dict: - """Convert to JSON-serializable dict""" - # TODO(bassosimone): ensure we include the correct - # annotation about the collector we're using - return { - "annotations": self.annotations, - "data_format_version": self.data_format_version, - "input": self.input, - "measurement_start_time": datetime_to_ooni_format( - self.measurement_start_time - ), - "probe_asn": self.probe_asn, - "probe_cc": self.probe_cc, - "test_keys": self.test_keys.as_dict(), - "test_name": self.test_name, - "test_runtime": self.test_runtime, - "test_start_time": datetime_to_ooni_format(self.test_start_time), - "test_version": self.test_version, - } +OONIMeasurement = oonireport.Measurement -- GitLab From 3066f598f44eb821c4ecc6841f8518d80c452323 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 16:36:52 +0100 Subject: [PATCH 29/75] refactor: move fieldtestingcsv into its own package --- fieldtestingcsv/__init__.py | 19 +++++ fieldtestingcsv/model.py | 64 ++++++++++++++++ fieldtestingcsv/parser.py | 105 +++++++++++++++++++++++++++ oonisubmitter/model.py | 141 ++---------------------------------- 4 files changed, 194 insertions(+), 135 deletions(-) create mode 100644 fieldtestingcsv/__init__.py create mode 100644 fieldtestingcsv/model.py create mode 100644 fieldtestingcsv/parser.py diff --git a/fieldtestingcsv/__init__.py b/fieldtestingcsv/__init__.py new file mode 100644 index 0000000..c378114 --- /dev/null +++ b/fieldtestingcsv/__init__.py @@ -0,0 +1,19 @@ +""" +Field-Testing CSV +================= + +This package contains code for managing the field-testing CSV +based data format (e.g., loading and parsing). + +See https://0xacab.org/leap/solitech-compose-client/-/blob/main/images/obfsvpn-openvpn-client/start.sh. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +# TODO(bassosimone): documented the expected CSV field at +# the package level rather than just linking to code. + +from .model import Entry +from .parser import parse + +__all__ = ["Entry", "parse"] diff --git a/fieldtestingcsv/model.py b/fieldtestingcsv/model.py new file mode 100644 index 0000000..3f94bd8 --- /dev/null +++ b/fieldtestingcsv/model.py @@ -0,0 +1,64 @@ +"""Field-Testing CSV model.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from datetime import datetime + + +@dataclass +class Entry: + """ + Models a single field-testing entry read from the CSV datastore. + + The order of the fields in this dataclass it the same + of the fields within the CSV file. + """ + + # Fields originally present in the CSV file + # format as of 2024-12-06 + filename: str + date: datetime + asn: str + isp: str + est_city: str + user: str + region: str + server_fqdn: str + server_ip: str + mobile: bool + tunnel: str # 'baseline', 'tunnel', 'ERROR/baseline', 'ERROR/tunnel' + throughput_download: float + throughput_upload: float + latency_download: float + latency_upload: float + retransmission_download: float + retransmission_upload: float + ping_packets_loss: float + ping_roundtrip_min: float + ping_roundtrip_avg: float + ping_roundtrip_max: float + err_message: str + protocol: str + + # Fields added on 2024-12-06 to allow for exporting + # endpoint-level aggregate tunnel metrics. + # + # TODO(XXX): update the CSV file spec and generation. + endpoint_hostname: str + endpoint_address: str + endpoint_port: int + endpoint_asn: str + endpoint_cc: str + ping_target_address: str + + def is_tunnel_measurement(self) -> bool: + """ + Return whether this is a tunnel measurement, which includes both + successful and failed tunnel measurements. + """ + return self.tunnel in ("tunnel", "ERROR/tunnel") + + def is_tunnel_error_measurement(self) -> bool: + """Return whether this is a failed tunnel measurement""" + return self.tunnel == "ERROR/tunnel" diff --git a/fieldtestingcsv/parser.py b/fieldtestingcsv/parser.py new file mode 100644 index 0000000..8582eed --- /dev/null +++ b/fieldtestingcsv/parser.py @@ -0,0 +1,105 @@ +"""Field-Testing CSV parser.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone +from typing import Iterator, List + +import csv +import logging + +from .model import Entry + + +def _parse_datetime(date_str: str) -> datetime: + """ + Parse ctime formatted date from CSV into datetime object. + + Example format: "Fri Dec 6 15:27:16 UTC 2024" + """ + # strptime directives: + # %a - Weekday name (Mon) + # %b - Month name (Nov) + # %d - Day of month (18) + # %H:%M:%S - Time (17:18:39) + # %Z - Timezone name (UTC) + # %Y - Year (2024) + dt = datetime.strptime(date_str, "%a %b %d %H:%M:%S %Z %Y") + + # For now, since we expect UTC, let's be strict + # + # TODO(bassosimone): do we need to care about non-UTC? + if "UTC" not in date_str: + raise ValueError(f"expected UTC timezone in date string, got: {date_str}") + + return dt.replace(tzinfo=timezone.utc) + + +def _parse_bool(value: str) -> bool: + """ + Parse boolean string from CSV into bool. + + Args: + value: String representation of boolean ("true" or "false") + + Returns: + True if value.lower() == "true", False otherwise + """ + return value.lower() == "true" + + +def parse(filename: str) -> List[Entry]: + """Parses and returns entries from CSV file.""" + entries = [] + + with open(filename, "r") as f: + reader = csv.DictReader(f) + for row in reader: + try: + + measurement = Entry( + filename=str(row["filename"]), + date=_parse_datetime(row["date"]), + asn=str(row["asn"]), + isp=str(row["isp"]), + est_city=str(row["est_city"]), + user=str(row["user"]), + region=str(row["region"]), + server_fqdn=str(row["server_fqdn"]), + server_ip=str(row["server_ip"]), + mobile=_parse_bool(row["mobile"]), + tunnel=str(row["tunnel"]), + throughput_download=float(row["throughput_download"]), + throughput_upload=float(row["throughput_upload"]), + latency_download=float(row["latency_download"]), + latency_upload=float(row["latency_upload"]), + retransmission_download=float(row["retransmission_download"]), + retransmission_upload=float(row["retransmission_upload"]), + ping_packets_loss=float(row["ping_packets_loss"]), + ping_roundtrip_min=float(row["ping_roundtrip_min"]), + ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), + ping_roundtrip_max=float(row["ping_roundtrip_max"]), + err_message=str(row["err_message"]).strip(), + protocol=str(row["PT"]), # rename from "PT" to "protocol" + endpoint_hostname=str(row["endpoint_hostname"]), + endpoint_address=str(row["endpoint_address"]), + endpoint_port=int(row["endpoint_port"]), + endpoint_asn=str(row["endpoint_asn"]), + endpoint_cc=str(row["endpoint_cc"]), + ping_target_address=str(row["ping_target_address"]), + ) + entries.append(measurement) + + except (ValueError, KeyError) as exc: + logging.warning(f"cannot import row: {exc}") + continue + + return entries + + +def stream(filename: str) -> Iterator[Entry]: + """Stream entries from CSV file one at a time.""" + # TODO(bassosimone): implement streaming, which is going to be more + # efficient than loading the whole file on memory, and then reimplement + # the `parse` function to use this function + raise NotImplementedError("streaming not yet implemented") diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 7354cf2..0d12299 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -34,6 +34,7 @@ from typing import Dict, List, Optional import csv import logging +import fieldtestingcsv import oonireport @@ -45,62 +46,7 @@ class Scope(Enum): GLOBAL = "global" -@dataclass -class FieldTestingCSVEntry: - """ - Models a single field-testing entry read from the CSV datastore. - - The order of the fields in this dataclass it the same - of the fields within the CSV file. - """ - - # Fields originally present in the CSV file - # format as of 2024-12-06 - filename: str - date: datetime - asn: str - isp: str - est_city: str - user: str - region: str - server_fqdn: str - server_ip: str - mobile: bool - tunnel: str # 'baseline', 'tunnel', 'ERROR/baseline', 'ERROR/tunnel' - throughput_download: float - throughput_upload: float - latency_download: float - latency_upload: float - retransmission_download: float - retransmission_upload: float - ping_packets_loss: float - ping_roundtrip_min: float - ping_roundtrip_avg: float - ping_roundtrip_max: float - err_message: str - protocol: str - - # Fields added on 2024-12-06 to allow for exporting - # endpoint-level aggregate tunnel metrics. - # - # TODO(XXX): update the CSV file spec and generation. - endpoint_hostname: str - endpoint_address: str - endpoint_port: int - endpoint_asn: str - endpoint_cc: str - ping_target_address: str - - def is_tunnel_measurement(self) -> bool: - """ - Return whether this is a tunnel measurement, which includes both - successful and failed tunnel measurements. - """ - return self.tunnel in ("tunnel", "ERROR/tunnel") - - def is_tunnel_error_measurement(self) -> bool: - """Return whether this is a failed tunnel measurement""" - return self.tunnel == "ERROR/tunnel" +FieldTestingCSVEntry = fieldtestingcsv.Entry class FieldTestingCSVFile: @@ -109,6 +55,9 @@ class FieldTestingCSVFile: to its entries through the `entries` property. """ + # TODO(bassosimone): I don't like this class much because it's + # not immutable, so maybe we should refactor it away. + def __init__(self, filename: str): """ Initialize with CSV filename to load. @@ -119,41 +68,6 @@ class FieldTestingCSVFile: self.filename = filename self._entries: Optional[List[FieldTestingCSVEntry]] = None - def _parse_datetime(self, date_str: str) -> datetime: - """ - Parse ctime formatted date from CSV into datetime object. - - Example format: "Fri Dec 6 15:27:16 UTC 2024" - """ - # strptime directives: - # %a - Weekday name (Mon) - # %b - Month name (Nov) - # %d - Day of month (18) - # %H:%M:%S - Time (17:18:39) - # %Z - Timezone name (UTC) - # %Y - Year (2024) - dt = datetime.strptime(date_str, "%a %b %d %H:%M:%S %Z %Y") - - # For now, since we expect UTC, let's be strict - # - # TODO(bassosimone): do we need to care about non-UTC? - if "UTC" not in date_str: - raise ValueError(f"expected UTC timezone in date string, got: {date_str}") - - return dt.replace(tzinfo=timezone.utc) - - def _parse_bool(self, value: str) -> bool: - """ - Parse boolean string from CSV into bool. - - Args: - value: String representation of boolean ("true" or "false") - - Returns: - True if value.lower() == "true", False otherwise - """ - return value.lower() == "true" - def load(self) -> List[FieldTestingCSVEntry]: """ Loads and returns entries from CSV file. @@ -161,50 +75,7 @@ class FieldTestingCSVFile: Also caches the entries and makes them available through the `entries` property. """ - entries = [] - - with open(self.filename, "r") as f: - reader = csv.DictReader(f) - for row in reader: - try: - - measurement = FieldTestingCSVEntry( - filename=str(row["filename"]), - date=self._parse_datetime(row["date"]), - asn=str(row["asn"]), - isp=str(row["isp"]), - est_city=str(row["est_city"]), - user=str(row["user"]), - region=str(row["region"]), - server_fqdn=str(row["server_fqdn"]), - server_ip=str(row["server_ip"]), - mobile=self._parse_bool(row["mobile"]), - tunnel=str(row["tunnel"]), - throughput_download=float(row["throughput_download"]), - throughput_upload=float(row["throughput_upload"]), - latency_download=float(row["latency_download"]), - latency_upload=float(row["latency_upload"]), - retransmission_download=float(row["retransmission_download"]), - retransmission_upload=float(row["retransmission_upload"]), - ping_packets_loss=float(row["ping_packets_loss"]), - ping_roundtrip_min=float(row["ping_roundtrip_min"]), - ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), - ping_roundtrip_max=float(row["ping_roundtrip_max"]), - err_message=str(row["err_message"]).strip(), - protocol=str(row["PT"]), # rename from "PT" to "protocol" - endpoint_hostname=str(row["endpoint_hostname"]), - endpoint_address=str(row["endpoint_address"]), - endpoint_port=int(row["endpoint_port"]), - endpoint_asn=str(row["endpoint_asn"]), - endpoint_cc=str(row["endpoint_cc"]), - ping_target_address=str(row["ping_target_address"]), - ) - entries.append(measurement) - - except (ValueError, KeyError) as exc: - logging.warning(f"cannot import row: {exc}") - continue - + entries = fieldtestingcsv.parse(self.filename) self._entries = entries return entries -- GitLab From 1c7476b029571bf48fe07638756ba40481424878 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 17:16:17 +0100 Subject: [PATCH 30/75] refactor: create tunnelmetrics specific package --- oonisubmitter/aggregator.py | 280 +---------------------------------- oonisubmitter/identifiers.py | 74 +-------- oonisubmitter/model.py | 9 +- tunnelmetrics/__init__.py | 16 ++ tunnelmetrics/endpoint.py | 250 +++++++++++++++++++++++++++++++ tunnelmetrics/identifiers.py | 71 +++++++++ tunnelmetrics/model.py | 40 +++++ 7 files changed, 391 insertions(+), 349 deletions(-) create mode 100644 tunnelmetrics/__init__.py create mode 100644 tunnelmetrics/endpoint.py create mode 100644 tunnelmetrics/identifiers.py create mode 100644 tunnelmetrics/model.py diff --git a/oonisubmitter/aggregator.py b/oonisubmitter/aggregator.py index c6344b1..9855e5b 100644 --- a/oonisubmitter/aggregator.py +++ b/oonisubmitter/aggregator.py @@ -1,276 +1,12 @@ -""" -Logic for aggregating field testing measurements into the aggregate -tunnel metrics OONI-compatible data format. +"""Backward copatibility model for tunnelmetrics.""" -See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md -""" +# TODO(bassosimone): use tunnelmetrics directly -from dataclasses import dataclass, field -from datetime import datetime -from typing import Dict, List -from .identifiers import BridgeEndpointID, NDTServerID -from . import model +from tunnelmetrics import ( + AggregatorConfig, + AggregateEndpointState, + EndpointAggregator, +) - -@dataclass -class AggregatorConfig: - """ - Configuration for the measurement aggregator. - """ - - # TODO(bassosimone): consider renaming this class - - provider: str - upstream_collector: str - probe_asn: str - probe_cc: str - scope: model.Scope = ( - model.Scope.ENDPOINT - ) # for now we only care about endpoint scope - - # threshold below which we emit sample_size - min_sample_size: int = 1000 - - # rounding sample_size to the nearest round_to - round_to: int = 100 - - software_name: str = "leap/aggregate-tunnel-metrics" - software_version: str = "0.1.0" - - -@dataclass -class AggregateEndpointState: - """ - Flat representation of an endpoint's aggregated state. - - All fields needed for measurement generation should be here. - """ - - # Core identification - hostname: str - address: str - port: int - protocol: str - - # Classification - asn: str - cc: str - - # Context from config - provider: str - probe_asn: str - probe_cc: str - scope: model.Scope - - # Time window - window_start: datetime - window_end: datetime - - # Error tracking - # - # Statistics about successes and failures using - # empty string to represent success - errors: Dict[str, int] = field(default_factory=dict) - - # Ping statistics organised by target - # - # { - # "8.8.8.8": { - # "min": [...], - # "avg": [...], - # "max": [...], - # "loss": [...] - # } - # } - ping_measurements: Dict[str, Dict[str, List[float]]] = field(default_factory=dict) - - # NDT statistics organised by target - # - # { - # "server.fqdn:ip": { - # "download_throughput": [...], - # "download_latency": [...], - # "download_retransmission": [...], - # "upload_throughput": [...], - # "upload_latency": [...], - # "upload_retransmission": [...], - # } - # } - ndt_measurements: Dict[str, Dict[str, List[float]]] = field(default_factory=dict) - - def update_error_counts(self, entry: model.FieldTestingCSVEntry) -> None: - """Update error counts based on a new entry""" - error_type = ( - "bootstrap.generic_error" if entry.is_tunnel_error_measurement() else "" - ) - self.errors[error_type] = self.errors.get(error_type, 0) + 1 - - def update_performance_metrics(self, entry: model.FieldTestingCSVEntry) -> None: - """Update performance metrics based on a new entry""" - if not entry.is_tunnel_error_measurement(): # only successful measurements - self._update_ping(entry) - self._update_ndt(entry) - - def _update_ping(self, entry: model.FieldTestingCSVEntry) -> None: - """Unconditionally update the ping metrics.""" - ping_target = entry.ping_target_address - if ping_target not in self.ping_measurements: - self.ping_measurements[ping_target] = { - "min": [], - "avg": [], - "max": [], - "loss": [], - } - metrics = self.ping_measurements[ping_target] - metrics["min"].append(entry.ping_roundtrip_min) - metrics["avg"].append(entry.ping_roundtrip_avg) - metrics["max"].append(entry.ping_roundtrip_max) - metrics["loss"].append(entry.ping_packets_loss) - - def _update_ndt(self, entry: model.FieldTestingCSVEntry) -> None: - """Unconditionally update the NDT metrics.""" - ndt_target = str( - NDTServerID(hostname=entry.server_fqdn, address=entry.server_ip) - ) - if ndt_target not in self.ndt_measurements: - self.ndt_measurements[ndt_target] = { - "download_throughput": [], - "download_latency": [], - "download_retransmission": [], - "upload_throughput": [], - "upload_latency": [], - "upload_retransmission": [], - } - metrics = self.ndt_measurements[ndt_target] - metrics["download_throughput"].append(entry.throughput_download) - metrics["download_latency"].append(entry.latency_download) - metrics["download_retransmission"].append(entry.retransmission_download) - metrics["upload_throughput"].append(entry.throughput_upload) - metrics["upload_latency"].append(entry.latency_upload) - metrics["upload_retransmission"].append(entry.retransmission_upload) - - @classmethod - def from_csv_entry( - cls, - entry: model.FieldTestingCSVEntry, - config: AggregatorConfig, - window_start: datetime, - window_end: datetime, - ) -> "AggregateEndpointState": - return cls( - hostname=entry.endpoint_hostname, - address=entry.endpoint_address, - port=entry.endpoint_port, - protocol=entry.protocol, - asn=entry.endpoint_asn, - cc=entry.endpoint_cc, - provider=config.provider, - probe_asn=config.probe_asn, - probe_cc=config.probe_cc, - scope=config.scope, - window_start=window_start, - window_end=window_end, - ) - - def to_dict(self) -> Dict: - """Convert state to a JSON-serializable dictionary""" - return { - # Core identification - "hostname": self.hostname, - "address": self.address, - "port": self.port, - "protocol": self.protocol, - # Classification - "asn": self.asn, - "cc": self.cc, - # Context - "provider": self.provider, - "probe_asn": self.probe_asn, - "probe_cc": self.probe_cc, - "scope": self.scope.value, - # Time window - "window_start": model.datetime_to_compact_utc(self.window_start), - "window_end": model.datetime_to_compact_utc(self.window_end), - # Measurements - "errors": self.errors, - "ping_measurements": self.ping_measurements, - "ndt_measurements": self.ndt_measurements, - } - - -class EndpointAggregator: - """ - Maintains state for multiple endpoints. - """ - - def __init__( - self, config: AggregatorConfig, window_start: datetime, window_end: datetime - ): - self.config = config - self.window_start = window_start - self.window_end = window_end - self.endpoints: Dict[str, AggregateEndpointState] = {} - - def _make_key(self, entry: model.FieldTestingCSVEntry) -> str: - """Create unique key for an endpoint""" - return str( - BridgeEndpointID( - hostname=entry.endpoint_hostname, - address=entry.endpoint_address, - port=entry.endpoint_port, - protocol=entry.protocol, - ) - ) - - def _is_in_window(self, entry: model.FieldTestingCSVEntry) -> bool: - """Check if entry falls within our time window""" - return self.window_start <= entry.date < self.window_end - - def _is_tunnel_entry(self, entry: model.FieldTestingCSVEntry) -> bool: - """Check if entry is a tunnel measurement""" - return entry.is_tunnel_measurement() - - def update(self, entry: model.FieldTestingCSVEntry) -> None: - """ - Update aggregator state with a new measurement. - """ - # Exclude outside-window events as well as the - # events related to the baseline measurement for - # which we don't have spec support for now - if not self._is_in_window(entry): - return - if not self._is_tunnel_entry(entry): - return - - # Make sure we are tracking this endpoint - key = self._make_key(entry) - if key not in self.endpoints: - self.endpoints[key] = AggregateEndpointState.from_csv_entry( - entry, self.config, self.window_start, self.window_end - ) - - # Update the endpoint statistics - epnt = self.endpoints[key] - epnt.update_performance_metrics(entry) - epnt.update_error_counts(entry) - - def to_dict(self) -> Dict: - """Convert aggregator state to a JSON-serializable dictionary""" - return { - # Config - "config": { - "provider": self.config.provider, - "upstream_collector": self.config.upstream_collector, - "probe_asn": self.config.probe_asn, - "probe_cc": self.config.probe_cc, - "scope": self.config.scope.value, - }, - # Time window - "window_start": model.datetime_to_compact_utc(self.window_start), - "window_end": model.datetime_to_compact_utc(self.window_end), - # Endpoints state - "endpoints": { - key: state.to_dict() for key, state in self.endpoints.items() - }, - } +__all__ = ["AggregatorConfig", "AggregateEndpointState", "EndpointAggregator"] diff --git a/oonisubmitter/identifiers.py b/oonisubmitter/identifiers.py index 32120b7..6abc706 100644 --- a/oonisubmitter/identifiers.py +++ b/oonisubmitter/identifiers.py @@ -1,73 +1,5 @@ -""" -Common identifiers used across the pipeline. -""" +"""Backward copatibility model for tunnelmetrics.identifiers""" -from dataclasses import dataclass +# TODO(bassosimone): use tunnelmetrics.identifiers directly -SEPARATOR = "|" - - -@dataclass -class NDTServerID: - """ - NDT server identifier. - """ - - # The hostname used by an NDT server (e.g., `ndt.example.com`). - hostname: str - - # The NDT server's IP address (e.g., `1.2.3.4`). - address: str - - def __str__(self) -> str: - """ - String representation of the NDT server identifier. - """ - return f"{self.hostname}{SEPARATOR}{self.address}" - - @classmethod - def parse(cls, identifier: str) -> "NDTServerID": - """ - Parse identifier string into a NDTServerID. - """ - hostname, address = identifier.split(SEPARATOR, 1) - if hostname == "" or address == "": - raise ValueError("Empty fields in identifier") - return cls(hostname=hostname, address=address) - - -@dataclass -class BridgeEndpointID: - """ - Bridge endpoint identifier. - """ - - # hostname used by the bridge endpoint (e.g., `bridge.example.com`). - hostname: str - - # IP address of the bridge endpoint (e.g., `1.2.3.4`) - address: str - - # Port number used by the bridge endpoint (e.g., `443`). - port: int - - # Protocol used by the bridge endpoint (e.g., `obfs4+kcp`). - protocol: str - - def __str__(self) -> str: - """ - String representation for dictionary keys. - """ - return f"{self.hostname}{SEPARATOR}{self.address}{SEPARATOR}{self.port}{SEPARATOR}{self.protocol}" - - @classmethod - def parse(cls, identifier: str) -> "BridgeEndpointID": - """ - Parse identifier string into an Endpoint. - """ - hostname, address, port, protocol = identifier.split(SEPARATOR, 3) - if hostname == "" or address == "" or port == "" or protocol == "": - raise ValueError("Empty fields in identifier") - return cls( - hostname=hostname, address=address, port=int(port), protocol=protocol - ) +from tunnelmetrics.identifiers import NDTServerID, BridgeEndpointID diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 0d12299..3718e58 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -34,16 +34,13 @@ from typing import Dict, List, Optional import csv import logging +from tunnelmetrics import model as modelx + import fieldtestingcsv import oonireport -class Scope(Enum): - """Valid scopes for aggregate tunnel metrics.""" - - ENDPOINT = "endpoint" - ENDPOINT_POOL = "endpoint_pool" - GLOBAL = "global" +Scope = modelx.Scope FieldTestingCSVEntry = fieldtestingcsv.Entry diff --git a/tunnelmetrics/__init__.py b/tunnelmetrics/__init__.py new file mode 100644 index 0000000..9b4e50f --- /dev/null +++ b/tunnelmetrics/__init__.py @@ -0,0 +1,16 @@ +""" +Tunnel Statistics Aggregation +============================ + +Library for aggregating tunnel performance statistics. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +""" + +# TODO(bassosimone): for Solitech we need to implement a more coarse +# grained aggregation approach than the endpoint aggregation. + +from .model import AggregatorConfig +from .endpoint import AggregateEndpointState, EndpointAggregator + +__all__ = ["AggregatorConfig", "AggregateEndpointState", "EndpointAggregator"] diff --git a/tunnelmetrics/endpoint.py b/tunnelmetrics/endpoint.py new file mode 100644 index 0000000..8880670 --- /dev/null +++ b/tunnelmetrics/endpoint.py @@ -0,0 +1,250 @@ +"""Rules to aggregate endpoints.""" + +from dataclasses import dataclass, field +from datetime import datetime +from typing import Dict, List + +import fieldtestingcsv + +from .identifiers import BridgeEndpointID, NDTServerID +from . import model + + +@dataclass +class AggregateEndpointState: + """ + Flat representation of an endpoint's aggregated state. + + All fields needed for measurement generation should be here. + """ + + # Core identification + hostname: str + address: str + port: int + protocol: str + + # Classification + asn: str + cc: str + + # Context from config + provider: str + probe_asn: str + probe_cc: str + scope: model.Scope + + # Time window + window_start: datetime + window_end: datetime + + # Error tracking + # + # Statistics about successes and failures using + # empty string to represent success + errors: Dict[str, int] = field(default_factory=dict) + + # Ping statistics organised by target + # + # { + # "8.8.8.8": { + # "min": [...], + # "avg": [...], + # "max": [...], + # "loss": [...] + # } + # } + ping_measurements: Dict[str, Dict[str, List[float]]] = field(default_factory=dict) + + # NDT statistics organised by target + # + # { + # "server.fqdn:ip": { + # "download_throughput": [...], + # "download_latency": [...], + # "download_retransmission": [...], + # "upload_throughput": [...], + # "upload_latency": [...], + # "upload_retransmission": [...], + # } + # } + ndt_measurements: Dict[str, Dict[str, List[float]]] = field(default_factory=dict) + + def update_error_counts(self, entry: fieldtestingcsv.Entry) -> None: + """Update error counts based on a new entry""" + error_type = ( + "bootstrap.generic_error" if entry.is_tunnel_error_measurement() else "" + ) + self.errors[error_type] = self.errors.get(error_type, 0) + 1 + + def update_performance_metrics(self, entry: fieldtestingcsv.Entry) -> None: + """Update performance metrics based on a new entry""" + if not entry.is_tunnel_error_measurement(): # only successful measurements + self._update_ping(entry) + self._update_ndt(entry) + + def _update_ping(self, entry: fieldtestingcsv.Entry) -> None: + """Unconditionally update the ping metrics.""" + ping_target = entry.ping_target_address + if ping_target not in self.ping_measurements: + self.ping_measurements[ping_target] = { + "min": [], + "avg": [], + "max": [], + "loss": [], + } + metrics = self.ping_measurements[ping_target] + metrics["min"].append(entry.ping_roundtrip_min) + metrics["avg"].append(entry.ping_roundtrip_avg) + metrics["max"].append(entry.ping_roundtrip_max) + metrics["loss"].append(entry.ping_packets_loss) + + def _update_ndt(self, entry: fieldtestingcsv.Entry) -> None: + """Unconditionally update the NDT metrics.""" + ndt_target = str( + NDTServerID(hostname=entry.server_fqdn, address=entry.server_ip) + ) + if ndt_target not in self.ndt_measurements: + self.ndt_measurements[ndt_target] = { + "download_throughput": [], + "download_latency": [], + "download_retransmission": [], + "upload_throughput": [], + "upload_latency": [], + "upload_retransmission": [], + } + metrics = self.ndt_measurements[ndt_target] + metrics["download_throughput"].append(entry.throughput_download) + metrics["download_latency"].append(entry.latency_download) + metrics["download_retransmission"].append(entry.retransmission_download) + metrics["upload_throughput"].append(entry.throughput_upload) + metrics["upload_latency"].append(entry.latency_upload) + metrics["upload_retransmission"].append(entry.retransmission_upload) + + @classmethod + def from_csv_entry( + cls, + entry: fieldtestingcsv.Entry, + config: model.AggregatorConfig, + window_start: datetime, + window_end: datetime, + ) -> "AggregateEndpointState": + return cls( + hostname=entry.endpoint_hostname, + address=entry.endpoint_address, + port=entry.endpoint_port, + protocol=entry.protocol, + asn=entry.endpoint_asn, + cc=entry.endpoint_cc, + provider=config.provider, + probe_asn=config.probe_asn, + probe_cc=config.probe_cc, + scope=config.scope, + window_start=window_start, + window_end=window_end, + ) + + def to_dict(self) -> Dict: + """Convert state to a JSON-serializable dictionary""" + return { + # Core identification + "hostname": self.hostname, + "address": self.address, + "port": self.port, + "protocol": self.protocol, + # Classification + "asn": self.asn, + "cc": self.cc, + # Context + "provider": self.provider, + "probe_asn": self.probe_asn, + "probe_cc": self.probe_cc, + "scope": self.scope.value, + # Time window + "window_start": model.datetime_to_compact_utc(self.window_start), + "window_end": model.datetime_to_compact_utc(self.window_end), + # Measurements + "errors": self.errors, + "ping_measurements": self.ping_measurements, + "ndt_measurements": self.ndt_measurements, + } + + +class EndpointAggregator: + """ + Maintains state for multiple endpoints. + """ + + def __init__( + self, + config: model.AggregatorConfig, + window_start: datetime, + window_end: datetime, + ): + self.config = config + self.window_start = window_start + self.window_end = window_end + self.endpoints: Dict[str, AggregateEndpointState] = {} + + def _make_key(self, entry: fieldtestingcsv.Entry) -> str: + """Create unique key for an endpoint""" + return str( + BridgeEndpointID( + hostname=entry.endpoint_hostname, + address=entry.endpoint_address, + port=entry.endpoint_port, + protocol=entry.protocol, + ) + ) + + def _is_in_window(self, entry: fieldtestingcsv.Entry) -> bool: + """Check if entry falls within our time window""" + return self.window_start <= entry.date < self.window_end + + def _is_tunnel_entry(self, entry: fieldtestingcsv.Entry) -> bool: + """Check if entry is a tunnel measurement""" + return entry.is_tunnel_measurement() + + def update(self, entry: fieldtestingcsv.Entry) -> None: + """ + Update aggregator state with a new measurement. + """ + # Exclude outside-window events as well as the + # events related to the baseline measurement for + # which we don't have spec support for now + if not self._is_in_window(entry): + return + if not self._is_tunnel_entry(entry): + return + + # Make sure we are tracking this endpoint + key = self._make_key(entry) + if key not in self.endpoints: + self.endpoints[key] = AggregateEndpointState.from_csv_entry( + entry, self.config, self.window_start, self.window_end + ) + + # Update the endpoint statistics + epnt = self.endpoints[key] + epnt.update_performance_metrics(entry) + epnt.update_error_counts(entry) + + def to_dict(self) -> Dict: + """Convert aggregator state to a JSON-serializable dictionary""" + return { + # Config + "config": { + "provider": self.config.provider, + "upstream_collector": self.config.upstream_collector, + "probe_asn": self.config.probe_asn, + "probe_cc": self.config.probe_cc, + "scope": self.config.scope.value, + }, + # Time window + "window_start": model.datetime_to_compact_utc(self.window_start), + "window_end": model.datetime_to_compact_utc(self.window_end), + # Endpoints state + "endpoints": { + key: state.to_dict() for key, state in self.endpoints.items() + }, + } diff --git a/tunnelmetrics/identifiers.py b/tunnelmetrics/identifiers.py new file mode 100644 index 0000000..f9f8cee --- /dev/null +++ b/tunnelmetrics/identifiers.py @@ -0,0 +1,71 @@ +"""Common identifiers used when aggregating tunnel metrics.""" + +from dataclasses import dataclass + +SEPARATOR = "|" + + +@dataclass +class NDTServerID: + """ + NDT server identifier. + """ + + # The hostname used by an NDT server (e.g., `ndt.example.com`). + hostname: str + + # The NDT server's IP address (e.g., `1.2.3.4`). + address: str + + def __str__(self) -> str: + """ + String representation of the NDT server identifier. + """ + return f"{self.hostname}{SEPARATOR}{self.address}" + + @classmethod + def parse(cls, identifier: str) -> "NDTServerID": + """ + Parse identifier string into a NDTServerID. + """ + hostname, address = identifier.split(SEPARATOR, 1) + if hostname == "" or address == "": + raise ValueError("Empty fields in identifier") + return cls(hostname=hostname, address=address) + + +@dataclass +class BridgeEndpointID: + """ + Bridge endpoint identifier. + """ + + # hostname used by the bridge endpoint (e.g., `bridge.example.com`). + hostname: str + + # IP address of the bridge endpoint (e.g., `1.2.3.4`) + address: str + + # Port number used by the bridge endpoint (e.g., `443`). + port: int + + # Protocol used by the bridge endpoint (e.g., `obfs4+kcp`). + protocol: str + + def __str__(self) -> str: + """ + String representation for dictionary keys. + """ + return f"{self.hostname}{SEPARATOR}{self.address}{SEPARATOR}{self.port}{SEPARATOR}{self.protocol}" + + @classmethod + def parse(cls, identifier: str) -> "BridgeEndpointID": + """ + Parse identifier string into an Endpoint. + """ + hostname, address, port, protocol = identifier.split(SEPARATOR, 3) + if hostname == "" or address == "" or port == "" or protocol == "": + raise ValueError("Empty fields in identifier") + return cls( + hostname=hostname, address=address, port=int(port), protocol=protocol + ) diff --git a/tunnelmetrics/model.py b/tunnelmetrics/model.py new file mode 100644 index 0000000..0e55c79 --- /dev/null +++ b/tunnelmetrics/model.py @@ -0,0 +1,40 @@ +"""Data model used by aggregation.""" + +from dataclasses import dataclass +from datetime import datetime, timezone +from enum import Enum + + +class Scope(Enum): + """Valid scopes for aggregate tunnel metrics.""" + + ENDPOINT = "endpoint" + ENDPOINT_POOL = "endpoint_pool" + GLOBAL = "global" + + +@dataclass +class AggregatorConfig: + """ + Configuration for the measurement aggregator. + """ + + provider: str + upstream_collector: str + probe_asn: str + probe_cc: str + scope: Scope = Scope.ENDPOINT # for now we only care about this + + # threshold below which we emit sample_size + min_sample_size: int = 1000 + + # rounding sample_size to the nearest round_to + round_to: int = 100 + + software_name: str = "leap/aggregate-tunnel-metrics" + software_version: str = "0.1.0" + + +def datetime_to_compact_utc(dt: datetime) -> str: + """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" + return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") -- GitLab From d02ab0a59df17591a4bd33ea040070c1524780be Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 17:42:08 +0100 Subject: [PATCH 31/75] refactor: create ooniformat package --- ooniformat/__init__.py | 13 ++ ooniformat/serializer.py | 254 +++++++++++++++++++++++++++++++++++ ooniformat/testkeys.py | 68 ++++++++++ oonisubmitter/model.py | 68 +--------- oonisubmitter/serializer.py | 260 +----------------------------------- 5 files changed, 347 insertions(+), 316 deletions(-) create mode 100644 ooniformat/__init__.py create mode 100644 ooniformat/serializer.py create mode 100644 ooniformat/testkeys.py diff --git a/ooniformat/__init__.py b/ooniformat/__init__.py new file mode 100644 index 0000000..587511b --- /dev/null +++ b/ooniformat/__init__.py @@ -0,0 +1,13 @@ +""" +OONI Format Library +================== + +Library for formatting data according to OONI specifications. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md. +""" + +from .serializer import Serializer +from .testkeys import TestKeys + +__all__ = ["Serializer", "TestKeys"] diff --git a/ooniformat/serializer.py b/ooniformat/serializer.py new file mode 100644 index 0000000..35f4017 --- /dev/null +++ b/ooniformat/serializer.py @@ -0,0 +1,254 @@ +"""Serialization implementation.""" + +from datetime import datetime +from statistics import quantiles +from typing import Any, Dict, List, Optional +from urllib.parse import urlunparse, urlencode + +from tunnelmetrics import AggregatorConfig, AggregateEndpointState +from tunnelmetrics.model import Scope +from tunnelmetrics.identifiers import NDTServerID + +from oonireport import Measurement + +from .testkeys import TestKeys, AggregationTimeWindow + + +class SerializationConfigError(Exception): + """Raised when serialization configuration does not allow us to proceed.""" + + +class Serializer: + """Converts aggregate endpoint state into OONI measurements""" + + def __init__(self, config: AggregatorConfig): + self.config = config + + @staticmethod + def _compute_percentiles(values: List[float]) -> Dict[str, float]: + """Compute the required percentiles for OONI format""" + + if not values: + return {} + + p25, p50, p75, p99 = quantiles(values, n=100, method="exclusive") + return { + "25p": round(p25, 1), + "50p": round(p50, 1), + "75p": round(p75, 1), + "99p": round(p99, 1), + } + + def _create_input_url(self, state: AggregateEndpointState) -> str: + """Create the measurement input URL""" + # Optionally include query + query = {} + if state.scope == Scope.ENDPOINT: + query = { + "address": state.address, + "asn": state.asn, + "hostname": state.hostname, + "port": str(state.port), + } + # Filter out None/empty values + query = {k: v for k, v in query.items() if v} + + # Build URL using urlunparse for safety + return urlunparse( + ( + state.protocol, # scheme (e.g., "openvpn+obfs4") + state.provider, # netloc (e.g., "riseup.net") + "/", # path + "", # params + urlencode(query), # query (e.g., "address=1.2.3.4&...") + "", # fragment + ) + ) + + def _round_sample_size(self, sample_size: int) -> Optional[int]: + """Round the sample size according to the aggregate tunnel metrics spec.""" + if sample_size < self.config.min_sample_size: + return None + return round(sample_size / self.config.round_to) * self.config.round_to + + @staticmethod + def _maybe_with_sample_size( + obj: Dict[str, Any], ss: Optional[int] + ) -> Dict[str, Any]: + if ss is not None: + obj["sample_size"] = ss + return obj + + def _create_error_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create error bodies if there are any errors""" + bodies = [] + total = sum(state.errors.values()) + if total > 0: + for error_type, count in state.errors.items(): + if not error_type: # Skip success counts + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": "creation", + "type": "network-error", + "failure_ratio": round(count / total, 2), + "error": error_type, + }, + self._round_sample_size(count), + ) + ) + return bodies + + def _create_ping_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create bodies for ping measurements""" + bodies = [] + for target_address, measurements in state.ping_measurements.items(): + # Min/Avg/Max latency bodies + for metric_type in ["min", "avg", "max"]: + if measurements[metric_type]: # Only if we have measurements + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ping", + "type": f"ping_{metric_type}", + "target_address": target_address, + "latency_ms": self._compute_percentiles( + measurements[metric_type] + ), + }, + self._round_sample_size(len(measurements[metric_type])), + ) + ) + + # Packet loss body + if measurements["loss"]: + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ping", + "type": "ping_loss", + "target_address": target_address, + "loss_percent": self._compute_percentiles( + measurements["loss"] + ), + }, + self._round_sample_size(len(measurements["loss"])), + ) + ) + + return bodies + + def _create_ndt_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create bodies for NDT measurements""" + bodies = [] + for target_id, measurements in state.ndt_measurements.items(): + server = NDTServerID.parse(target_id) + + # Download measurements + if measurements["download_throughput"]: + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ndt_download", + "type": "ndt_download", + "target_hostname": server.hostname, + "target_address": server.address, + "target_port": 443, # TODO: Get actual port + "latency_ms": self._compute_percentiles( + measurements["download_latency"] + ), + "speed_mbits": self._compute_percentiles( + measurements["download_throughput"] + ), + "retransmission_percent": self._compute_percentiles( + measurements["download_retransmission"] + ), + }, + self._round_sample_size( + len(measurements["download_throughput"]) + ), + ) + ) + + # Upload measurements + if measurements["upload_throughput"]: + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ndt_upload", + "type": "ndt_upload", + "target_hostname": server.hostname, + "target_address": server.address, + "target_port": 443, # TODO(bassosimone): get actual port + "latency_ms": self._compute_percentiles( + measurements["upload_latency"] + ), + "speed_mbits": self._compute_percentiles( + measurements["upload_throughput"] + ), + "retransmission_percent": self._compute_percentiles( + measurements["upload_retransmission"] + ), + }, + self._round_sample_size(len(measurements["upload_throughput"])), + ) + ) + + return bodies + + def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: + """Create the bodies section of test_keys""" + bodies = [] + bodies.extend(self._create_error_bodies(state)) + bodies.extend(self._create_ping_bodies(state)) + bodies.extend(self._create_ndt_bodies(state)) + return bodies + + def serialize(self, state: AggregateEndpointState) -> Measurement: + """ + Convert endpoint state to OONI measurement format. + + Raises: + SerializationError: if the scope is not model.Scope.ENDPOINT. + """ + if state.scope != Scope.ENDPOINT: + raise SerializationConfigError( + f"cannot serialize measurement with scope '{state.scope}': " + "only 'endpoint' scope is currently supported" + ) + + measurement_time = datetime.utcnow() + + test_keys = TestKeys( + provider=state.provider, + scope=state.scope, + protocol=state.protocol, + time_window=AggregationTimeWindow( + from_time=state.window_start, to_time=state.window_end + ), + endpoint_hostname=( + state.hostname if state.scope == Scope.ENDPOINT else None + ), + endpoint_address=(state.address if state.scope == Scope.ENDPOINT else None), + endpoint_port=state.port if state.scope == Scope.ENDPOINT else None, + asn=state.asn, + cc=state.cc, + bodies=self._create_bodies(state), + ) + + return Measurement( + annotations={"upstream_collector": self.config.upstream_collector}, + data_format_version="0.2.0", + input=self._create_input_url(state), + measurement_start_time=measurement_time, + probe_asn=self.config.probe_asn, + probe_cc=self.config.probe_cc, + software_name=self.config.software_name, + software_version=self.config.software_version, + test_keys=test_keys, + test_name="aggregate_tunnel_metrics", + test_runtime=0.0, + test_start_time=measurement_time, + test_version="0.1.0", + ) diff --git a/ooniformat/testkeys.py b/ooniformat/testkeys.py new file mode 100644 index 0000000..ea576f2 --- /dev/null +++ b/ooniformat/testkeys.py @@ -0,0 +1,68 @@ +"""TestKeys for the aggregate tunnel metrics specification.""" + +from datetime import datetime +from dataclasses import dataclass +from typing import Dict, List, Optional + +from tunnelmetrics.model import Scope, datetime_to_compact_utc + + +@dataclass +class AggregationTimeWindow: + """Time window for aggregating measurements""" + + from_time: datetime + to_time: datetime + + def as_dict(self) -> Dict: + """Convert to JSON-serializable dict""" + return { + "from": datetime_to_compact_utc(self.from_time), + "to": datetime_to_compact_utc(self.to_time), + } + + +@dataclass +class TestKeys: + """ + Models the test_keys portion of an OONI measurement as defined + in the aggregate tunnel metrics specification. + """ + + provider: str + scope: Scope + protocol: str + time_window: AggregationTimeWindow + + # Optional fields depending on scope + endpoint_hostname: Optional[str] + endpoint_address: Optional[str] + endpoint_port: Optional[int] + asn: Optional[str] # Format: ^AS[0-9]+$ + cc: Optional[str] # Format: ^[A-Z]{2}$ + bodies: List[Dict] # TODO (bassosimone): we can make this more specific later + + def as_dict(self) -> Dict: + """Convert to JSON-serializable dict""" + # Start with required fields + d = { + "provider": self.provider, + "scope": self.scope, + "protocol": self.protocol, + "time_window": self.time_window.as_dict(), + "bodies": self.bodies, + } + + # Add optional fields if they exist + for field in [ + "endpoint_hostname", + "endpoint_address", + "endpoint_port", + "asn", + "cc", + ]: + value = getattr(self, field) + if value is not None: + d[field] = value + + return d diff --git a/oonisubmitter/model.py b/oonisubmitter/model.py index 3718e58..49fcadb 100644 --- a/oonisubmitter/model.py +++ b/oonisubmitter/model.py @@ -38,6 +38,9 @@ from tunnelmetrics import model as modelx import fieldtestingcsv import oonireport +import tunnelmetrics + +from ooniformat import testkeys Scope = modelx.Scope @@ -84,73 +87,14 @@ class FieldTestingCSVFile: return self._entries -def datetime_to_compact_utc(dt: datetime) -> str: - """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" - return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") +datetime_to_compact_utc = modelx.datetime_to_compact_utc datetime_to_ooni_format = oonireport.datetime_to_ooni_format -@dataclass -class AggregationTimeWindow: - """Time window for aggregating measurements""" - - from_time: datetime - to_time: datetime - - def as_dict(self) -> Dict: - """Convert to JSON-serializable dict""" - return { - "from": datetime_to_compact_utc(self.from_time), - "to": datetime_to_compact_utc(self.to_time), - } - - -@dataclass -class AggregateTunnelMetricsTestKeys: - """ - Models the test_keys portion of an OONI measurement as defined - in the aggregate tunnel metrics specification. - """ - - provider: str - scope: Scope - protocol: str - time_window: AggregationTimeWindow - - # Optional fields depending on scope - endpoint_hostname: Optional[str] - endpoint_address: Optional[str] - endpoint_port: Optional[int] - asn: Optional[str] # Format: ^AS[0-9]+$ - cc: Optional[str] # Format: ^[A-Z]{2}$ - bodies: List[Dict] # TODO (bassosimone): we can make this more specific later - - def as_dict(self) -> Dict: - """Convert to JSON-serializable dict""" - # Start with required fields - d = { - "provider": self.provider, - "scope": self.scope, - "protocol": self.protocol, - "time_window": self.time_window.as_dict(), - "bodies": self.bodies, - } - - # Add optional fields if they exist - for field in [ - "endpoint_hostname", - "endpoint_address", - "endpoint_port", - "asn", - "cc", - ]: - value = getattr(self, field) - if value is not None: - d[field] = value - - return d +AggregationTimeWindow = testkeys.AggregationTimeWindow +AggregateTunnelMetricsTestKeys = testkeys.TestKeys OONIMeasurement = oonireport.Measurement diff --git a/oonisubmitter/serializer.py b/oonisubmitter/serializer.py index ecf7c6a..cc5ebce 100644 --- a/oonisubmitter/serializer.py +++ b/oonisubmitter/serializer.py @@ -1,256 +1,8 @@ -""" -Serializes aggregated endpoint state into OONI measurements format. +"""Backward-compatibility for the ooniformat package.""" -See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md -""" +# TODO(bassosimone): refactor this module away -from datetime import datetime -from typing import Any, Dict, List, Optional -from statistics import quantiles -from urllib.parse import urlunparse, urlencode - -from .identifiers import NDTServerID -from .aggregator import AggregateEndpointState, AggregatorConfig -from . import model - - -class SerializationConfigError(Exception): - """Raised when serialization configuration does not allow us to proceed.""" - - -class OONISerializer: - """Converts aggregate endpoint state into OONI measurements""" - - def __init__(self, config: AggregatorConfig): - self.config = config - - @staticmethod - def _compute_percentiles(values: List[float]) -> Dict[str, float]: - """Compute the required percentiles for OONI format""" - - if not values: - return {} - - p25, p50, p75, p99 = quantiles(values, n=100, method="exclusive") - return { - "25p": round(p25, 1), - "50p": round(p50, 1), - "75p": round(p75, 1), - "99p": round(p99, 1), - } - - def _create_input_url(self, state: AggregateEndpointState) -> str: - """Create the measurement input URL""" - # Optionally include query - query = {} - if state.scope == model.Scope.ENDPOINT: - query = { - "address": state.address, - "asn": state.asn, - "hostname": state.hostname, - "port": str(state.port), - } - # Filter out None/empty values - query = {k: v for k, v in query.items() if v} - - # Build URL using urlunparse for safety - return urlunparse( - ( - state.protocol, # scheme (e.g., "openvpn+obfs4") - state.provider, # netloc (e.g., "riseup.net") - "/", # path - "", # params - urlencode(query), # query (e.g., "address=1.2.3.4&...") - "", # fragment - ) - ) - - def _round_sample_size(self, sample_size: int) -> Optional[int]: - """Round the sample size according to the aggregate tunnel metrics spec.""" - if sample_size < self.config.min_sample_size: - return None - return round(sample_size / self.config.round_to) * self.config.round_to - - @staticmethod - def _maybe_with_sample_size( - obj: Dict[str, Any], ss: Optional[int] - ) -> Dict[str, Any]: - if ss is not None: - obj["sample_size"] = ss - return obj - - def _create_error_bodies(self, state: AggregateEndpointState) -> List[Dict]: - """Create error bodies if there are any errors""" - bodies = [] - total = sum(state.errors.values()) - if total > 0: - for error_type, count in state.errors.items(): - if not error_type: # Skip success counts - continue - bodies.append( - self._maybe_with_sample_size( - { - "phase": "creation", - "type": "network-error", - "failure_ratio": round(count / total, 2), - "error": error_type, - }, - self._round_sample_size(count), - ) - ) - return bodies - - def _create_ping_bodies(self, state: AggregateEndpointState) -> List[Dict]: - """Create bodies for ping measurements""" - bodies = [] - for target_address, measurements in state.ping_measurements.items(): - # Min/Avg/Max latency bodies - for metric_type in ["min", "avg", "max"]: - if measurements[metric_type]: # Only if we have measurements - bodies.append( - self._maybe_with_sample_size( - { - "phase": "tunnel_ping", - "type": f"ping_{metric_type}", - "target_address": target_address, - "latency_ms": self._compute_percentiles( - measurements[metric_type] - ), - }, - self._round_sample_size(len(measurements[metric_type])), - ) - ) - - # Packet loss body - if measurements["loss"]: - bodies.append( - self._maybe_with_sample_size( - { - "phase": "tunnel_ping", - "type": "ping_loss", - "target_address": target_address, - "loss_percent": self._compute_percentiles( - measurements["loss"] - ), - }, - self._round_sample_size(len(measurements["loss"])), - ) - ) - - return bodies - - def _create_ndt_bodies(self, state: AggregateEndpointState) -> List[Dict]: - """Create bodies for NDT measurements""" - bodies = [] - for target_id, measurements in state.ndt_measurements.items(): - server = NDTServerID.parse(target_id) - - # Download measurements - if measurements["download_throughput"]: - bodies.append( - self._maybe_with_sample_size( - { - "phase": "tunnel_ndt_download", - "type": "ndt_download", - "target_hostname": server.hostname, - "target_address": server.address, - "target_port": 443, # TODO: Get actual port - "latency_ms": self._compute_percentiles( - measurements["download_latency"] - ), - "speed_mbits": self._compute_percentiles( - measurements["download_throughput"] - ), - "retransmission_percent": self._compute_percentiles( - measurements["download_retransmission"] - ), - }, - self._round_sample_size( - len(measurements["download_throughput"]) - ), - ) - ) - - # Upload measurements - if measurements["upload_throughput"]: - bodies.append( - self._maybe_with_sample_size( - { - "phase": "tunnel_ndt_upload", - "type": "ndt_upload", - "target_hostname": server.hostname, - "target_address": server.address, - "target_port": 443, # TODO(bassosimone): get actual port - "latency_ms": self._compute_percentiles( - measurements["upload_latency"] - ), - "speed_mbits": self._compute_percentiles( - measurements["upload_throughput"] - ), - "retransmission_percent": self._compute_percentiles( - measurements["upload_retransmission"] - ), - }, - self._round_sample_size(len(measurements["upload_throughput"])), - ) - ) - - return bodies - - def _create_bodies(self, state: AggregateEndpointState) -> List[Dict]: - """Create the bodies section of test_keys""" - bodies = [] - bodies.extend(self._create_error_bodies(state)) - bodies.extend(self._create_ping_bodies(state)) - bodies.extend(self._create_ndt_bodies(state)) - return bodies - - def serialize(self, state: AggregateEndpointState) -> model.OONIMeasurement: - """ - Convert endpoint state to OONI measurement format. - - Raises: - SerializationError: if the scope is not model.Scope.ENDPOINT. - """ - if state.scope != model.Scope.ENDPOINT: - raise SerializationConfigError( - f"cannot serialize measurement with scope '{state.scope}': " - "only 'endpoint' scope is currently supported" - ) - - measurement_time = datetime.utcnow() - - test_keys = model.AggregateTunnelMetricsTestKeys( - provider=state.provider, - scope=state.scope, - protocol=state.protocol, - time_window=model.AggregationTimeWindow( - from_time=state.window_start, to_time=state.window_end - ), - endpoint_hostname=( - state.hostname if state.scope == model.Scope.ENDPOINT else None - ), - endpoint_address=( - state.address if state.scope == model.Scope.ENDPOINT else None - ), - endpoint_port=state.port if state.scope == model.Scope.ENDPOINT else None, - asn=state.asn, - cc=state.cc, - bodies=self._create_bodies(state), - ) - - return model.OONIMeasurement( - annotations={"upstream_collector": self.config.upstream_collector}, - data_format_version="0.2.0", - input=self._create_input_url(state), - measurement_start_time=measurement_time, - probe_asn=self.config.probe_asn, - probe_cc=self.config.probe_cc, - software_name=self.config.software_name, - software_version=self.config.software_version, - test_keys=test_keys, - test_name="aggregate_tunnel_metrics", - test_runtime=0.0, - test_start_time=measurement_time, - test_version="0.1.0", - ) +from ooniformat.serializer import ( + SerializationConfigError, + Serializer as OONISerializer, +) -- GitLab From 0a747cc3c52686e2aafd1df2e7feb2f5d9e77369 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 17:58:25 +0100 Subject: [PATCH 32/75] cleanup: remove duplicate ooniapi.py code --- oonireport/__init__.py | 2 + oonisubmitter/ooniapi.py | 159 ++------------------------------------- 2 files changed, 10 insertions(+), 151 deletions(-) diff --git a/oonireport/__init__.py b/oonireport/__init__.py index e7b1a15..e1f1b26 100644 --- a/oonireport/__init__.py +++ b/oonireport/__init__.py @@ -7,6 +7,8 @@ servers and submitting OONI measurements to them. See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. +We implement the OONI collector protocol v3.0.0. + Classes: CollectorClient: A client for interacting with OONI collectors. CollectorConfig: Configuration for the `CollectorClient`. diff --git a/oonisubmitter/ooniapi.py b/oonisubmitter/ooniapi.py index f24f74e..41c0dc9 100644 --- a/oonisubmitter/ooniapi.py +++ b/oonisubmitter/ooniapi.py @@ -1,154 +1,11 @@ -""" -OONI API client for submitting measurements to collectors. +"""Backward compatibility for the oonireport package.""" -Implements the OONI collector protocol v3.0.0. +# TODO(bassosimone): remove this backward copatibility layer. -See: https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md -""" +from oonireport import ( + APIError as OONIAPIError, + CollectorClient as Client, + CollectorConfig as Config, +) -from dataclasses import dataclass -from typing import Optional - -import json -import urllib.request -import urllib.error - -from . import model - - -class OONIAPIError(Exception): - """Raised when there are OONI API errors.""" - - -@dataclass -class Config: - """Configuration for OONI API client""" - - collector_base_url: str # e.g., "https://api.ooni.io/" - timeout: float = 30.0 - - -class Client: - """ - Client for interacting with OONI collectors implementing v3.0.0 - of the specification. - """ - - # TODO(bassosimone): consider support for retries. The spec - # says the following about retries: - # - # > A client side implementation MAY retry any failing collector - # > operation immediately for three times in case there is a - # > DNS or TCP error. [...] If all these immediate retries fail, - # > then the client SHOULD arrange for resubmitting the - # > measurement at a later time - - def __init__(self, config: Config): - self.config = config - - def create_report_from_measurement( - self, - measurement: model.OONIMeasurement, - ) -> str: - """Convenience method to create report from existing measurement.""" - return self.create_report( - test_name=measurement.test_name, - test_version=measurement.test_version, - software_name=measurement.software_name, - software_version=measurement.software_version, - probe_asn=measurement.probe_asn, - probe_cc=measurement.probe_cc, - ) - - def create_report( - self, - test_name: str, - test_version: str, - software_name: str, - software_version: str, - probe_asn: str, - probe_cc: str, - ) -> str: - """ - Create a new report. - - Returns: - str: Report ID to use for submitting measurements - - Raises: - OONIAPIError: If submission fails - URLError: For network/DNS issues - """ - report = { - "data_format_version": "0.2.0", - "format": "json", - "probe_asn": probe_asn, - "probe_cc": probe_cc, - "software_name": software_name, - "software_version": software_version, - "test_name": test_name, - "test_version": test_version, - } - - data = json.dumps(report).encode("utf-8") - req = urllib.request.Request( - f"{self.config.collector_base_url}/report", - data=data, - headers={"Content-Type": "application/json"}, - method="POST", - ) - - try: - with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: - if resp.status != 200: - raise OONIAPIError(f"unexpected status: {resp.status}") - response = json.loads(resp.read().decode()) - if "report_id" not in response: - raise OONIAPIError("missing report_id in response") - return response["report_id"] - - except urllib.error.HTTPError as err: - raise OONIAPIError(f"HTTP error: {err}") - - def update_report( - self, - report_id: str, - measurement: model.OONIMeasurement, - ) -> Optional[str]: - """ - Update a report by adding a measurement. - - Args: - report_id: The ID returned by create_report() - measurement: The measurement to submit - - Returns: - Optional[str]: measurement_id if provided by server - - Raises: - OONIAPIError: If submission fails - URLError: For network/DNS issues - """ - data = json.dumps({"format": "json", "content": measurement.as_dict()}).encode( - "utf-8" - ) - - req = urllib.request.Request( - f"{self.config.collector_base_url}/report/{report_id}", - data=data, - headers={"Content-Type": "application/json"}, - method="POST", - ) - - try: - with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: - if resp.status != 200: - raise OONIAPIError(f"unexpected status: {resp.status}") - try: - response = json.loads(resp.read().decode()) - return response.get("measurement_id") - except json.JSONDecodeError: - return None - - except urllib.error.HTTPError as err: - raise OONIAPIError(f"HTTP error: {err}") +__all__ = ["OONIAPIError", "Client", "Config"] -- GitLab From 9ca3c06f2369a7e6eacc31321ec3950a259e473d Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 1 Feb 2025 18:02:11 +0100 Subject: [PATCH 33/75] chore: ensure all files have license identifier --- ooniformat/__init__.py | 3 +++ ooniformat/serializer.py | 3 +++ ooniformat/testkeys.py | 3 +++ tunnelmetrics/__init__.py | 3 +++ tunnelmetrics/endpoint.py | 3 +++ tunnelmetrics/identifiers.py | 3 +++ tunnelmetrics/model.py | 3 +++ 7 files changed, 21 insertions(+) diff --git a/ooniformat/__init__.py b/ooniformat/__init__.py index 587511b..8275526 100644 --- a/ooniformat/__init__.py +++ b/ooniformat/__init__.py @@ -7,6 +7,9 @@ Library for formatting data according to OONI specifications. See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md. """ +# SPDX-License-Identifier: GPL-3.0-or-later + + from .serializer import Serializer from .testkeys import TestKeys diff --git a/ooniformat/serializer.py b/ooniformat/serializer.py index 35f4017..10c4bb4 100644 --- a/ooniformat/serializer.py +++ b/ooniformat/serializer.py @@ -1,5 +1,8 @@ """Serialization implementation.""" +# SPDX-License-Identifier: GPL-3.0-or-later + + from datetime import datetime from statistics import quantiles from typing import Any, Dict, List, Optional diff --git a/ooniformat/testkeys.py b/ooniformat/testkeys.py index ea576f2..69ceac9 100644 --- a/ooniformat/testkeys.py +++ b/ooniformat/testkeys.py @@ -1,5 +1,8 @@ """TestKeys for the aggregate tunnel metrics specification.""" +# SPDX-License-Identifier: GPL-3.0-or-later + + from datetime import datetime from dataclasses import dataclass from typing import Dict, List, Optional diff --git a/tunnelmetrics/__init__.py b/tunnelmetrics/__init__.py index 9b4e50f..8d2eec8 100644 --- a/tunnelmetrics/__init__.py +++ b/tunnelmetrics/__init__.py @@ -7,6 +7,9 @@ Library for aggregating tunnel performance statistics. See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md """ +# SPDX-License-Identifier: GPL-3.0-or-later + + # TODO(bassosimone): for Solitech we need to implement a more coarse # grained aggregation approach than the endpoint aggregation. diff --git a/tunnelmetrics/endpoint.py b/tunnelmetrics/endpoint.py index 8880670..d0754f8 100644 --- a/tunnelmetrics/endpoint.py +++ b/tunnelmetrics/endpoint.py @@ -1,5 +1,8 @@ """Rules to aggregate endpoints.""" +# SPDX-License-Identifier: GPL-3.0-or-later + + from dataclasses import dataclass, field from datetime import datetime from typing import Dict, List diff --git a/tunnelmetrics/identifiers.py b/tunnelmetrics/identifiers.py index f9f8cee..35f7bdf 100644 --- a/tunnelmetrics/identifiers.py +++ b/tunnelmetrics/identifiers.py @@ -1,5 +1,8 @@ """Common identifiers used when aggregating tunnel metrics.""" +# SPDX-License-Identifier: GPL-3.0-or-later + + from dataclasses import dataclass SEPARATOR = "|" diff --git a/tunnelmetrics/model.py b/tunnelmetrics/model.py index 0e55c79..ca8b8e4 100644 --- a/tunnelmetrics/model.py +++ b/tunnelmetrics/model.py @@ -1,5 +1,8 @@ """Data model used by aggregation.""" +# SPDX-License-Identifier: GPL-3.0-or-later + + from dataclasses import dataclass from datetime import datetime, timezone from enum import Enum -- GitLab From b9139277bb0d77fc3db07af186d35c87084dbfca Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 3 Feb 2025 21:00:14 +0100 Subject: [PATCH 34/75] fix: start adapting to global scope While there start adding unit tests as well. --- fieldtestingcsv/model.py | 9 +- fieldtestingcsv/parser.py | 5 - tests/fieldtestingcsv/__init__.py | 0 tests/fieldtestingcsv/test_parser.py | 150 +++++++++++++++++++++++++++ 4 files changed, 152 insertions(+), 12 deletions(-) create mode 100644 tests/fieldtestingcsv/__init__.py create mode 100644 tests/fieldtestingcsv/test_parser.py diff --git a/fieldtestingcsv/model.py b/fieldtestingcsv/model.py index 3f94bd8..4da841c 100644 --- a/fieldtestingcsv/model.py +++ b/fieldtestingcsv/model.py @@ -41,15 +41,10 @@ class Entry: err_message: str protocol: str - # Fields added on 2024-12-06 to allow for exporting - # endpoint-level aggregate tunnel metrics. + # Fields added on 2025-02-03 to allow for exporting + # the global aggregate tunnel metrics. # # TODO(XXX): update the CSV file spec and generation. - endpoint_hostname: str - endpoint_address: str - endpoint_port: int - endpoint_asn: str - endpoint_cc: str ping_target_address: str def is_tunnel_measurement(self) -> bool: diff --git a/fieldtestingcsv/parser.py b/fieldtestingcsv/parser.py index 8582eed..a2d9fbb 100644 --- a/fieldtestingcsv/parser.py +++ b/fieldtestingcsv/parser.py @@ -81,11 +81,6 @@ def parse(filename: str) -> List[Entry]: ping_roundtrip_max=float(row["ping_roundtrip_max"]), err_message=str(row["err_message"]).strip(), protocol=str(row["PT"]), # rename from "PT" to "protocol" - endpoint_hostname=str(row["endpoint_hostname"]), - endpoint_address=str(row["endpoint_address"]), - endpoint_port=int(row["endpoint_port"]), - endpoint_asn=str(row["endpoint_asn"]), - endpoint_cc=str(row["endpoint_cc"]), ping_target_address=str(row["ping_target_address"]), ) entries.append(measurement) diff --git a/tests/fieldtestingcsv/__init__.py b/tests/fieldtestingcsv/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/fieldtestingcsv/test_parser.py b/tests/fieldtestingcsv/test_parser.py new file mode 100644 index 0000000..08291e2 --- /dev/null +++ b/tests/fieldtestingcsv/test_parser.py @@ -0,0 +1,150 @@ +import unittest +from datetime import datetime, timezone +import tempfile +import os + +from fieldtestingcsv import parse +from fieldtestingcsv.model import Entry + + +class TestFieldTestingCSVParser(unittest.TestCase): + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.csv_path = os.path.join(self.temp_dir, "test.csv") + + def tearDown(self): + try: + os.unlink(self.csv_path) + os.rmdir(self.temp_dir) + except FileNotFoundError: + pass + + def write_csv(self, content: str): + """Helper to write CSV content to temp file""" + with open(self.csv_path, "w") as f: + f.write(content) + + def test_parse_valid_entry(self): + """Test parsing a valid CSV row""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,ping_target_address +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,8.8.8.8""" + + self.write_csv(csv_content) + entries = parse(self.csv_path) + + self.assertEqual(len(entries), 1) + entry = entries[0] + + self.assertEqual(entry.filename, "test.csv") + self.assertEqual(entry.asn, "AS12345") + self.assertEqual(entry.protocol, "obfs4") # Note: mapped from PT + self.assertEqual(entry.ping_target_address, "8.8.8.8") + self.assertEqual(entry.throughput_download, 100.0) + self.assertEqual(entry.mobile, False) + + # Check datetime parsing + expected_dt = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc) + self.assertEqual(entry.date, expected_dt) + + def test_parse_invalid_date(self): + """Test handling invalid date formats""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,ping_target_address +test.csv,invalid date,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,8.8.8.8""" + + self.write_csv(csv_content) + entries = parse(self.csv_path) + self.assertEqual(len(entries), 0) + + def test_parse_invalid_numeric(self): + """Test handling invalid numeric values""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,ping_target_address +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,invalid,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,8.8.8.8""" + + self.write_csv(csv_content) + entries = parse(self.csv_path) + self.assertEqual(len(entries), 0) + + def test_missing_required_field(self): + """Test handling missing required fields""" + # Create CSV with missing ping_target_address + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" + + self.write_csv(csv_content) + entries = parse(self.csv_path) + self.assertEqual(len(entries), 0) + + def test_is_tunnel_measurement(self): + """Test is_tunnel_measurement() method""" + entry = Entry( + filename="test.csv", + date=datetime.now(timezone.utc), + asn="AS12345", + isp="TestISP", + est_city="TestCity", + user="user1", + region="TestRegion", + server_fqdn="ndt.example.com", + server_ip="1.2.3.4", + mobile=False, + tunnel="tunnel", + throughput_download=100.0, + throughput_upload=50.0, + latency_download=20.0, + latency_upload=25.0, + retransmission_download=0.01, + retransmission_upload=0.02, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=12.0, + ping_roundtrip_max=15.0, + err_message="", + protocol="obfs4", + ping_target_address="8.8.8.8", + ) + + self.assertTrue(entry.is_tunnel_measurement()) + + entry.tunnel = "baseline" + self.assertFalse(entry.is_tunnel_measurement()) + + entry.tunnel = "ERROR/tunnel" + self.assertTrue(entry.is_tunnel_measurement()) + + def test_is_tunnel_error_measurement(self): + """Test is_tunnel_error_measurement() method""" + entry = Entry( + filename="test.csv", + date=datetime.now(timezone.utc), + asn="AS12345", + isp="TestISP", + est_city="TestCity", + user="user1", + region="TestRegion", + server_fqdn="ndt.example.com", + server_ip="1.2.3.4", + mobile=False, + tunnel="ERROR/tunnel", + throughput_download=100.0, + throughput_upload=50.0, + latency_download=20.0, + latency_upload=25.0, + retransmission_download=0.01, + retransmission_upload=0.02, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=12.0, + ping_roundtrip_max=15.0, + err_message="", + protocol="obfs4", + ping_target_address="8.8.8.8", + ) + + self.assertTrue(entry.is_tunnel_error_measurement()) + + entry.tunnel = "tunnel" + self.assertFalse(entry.is_tunnel_error_measurement()) + + +if __name__ == "__main__": + unittest.main() -- GitLab From 11187d6d11ca3615147bb6e5fef555144de33955 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Tue, 4 Feb 2025 21:23:02 +0100 Subject: [PATCH 35/75] feat(lockedfile): write tests for read and write --- tests/lockedfile/__init__.py | 0 tests/lockedfile/fileio.py | 98 ++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 tests/lockedfile/__init__.py create mode 100644 tests/lockedfile/fileio.py diff --git a/tests/lockedfile/__init__.py b/tests/lockedfile/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/lockedfile/fileio.py b/tests/lockedfile/fileio.py new file mode 100644 index 0000000..a5da0d2 --- /dev/null +++ b/tests/lockedfile/fileio.py @@ -0,0 +1,98 @@ +"""Tests for the lockedfile/fileio.py functionality.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Set + +import multiprocessing as mp +import os +import tempfile +import time +import unittest + +from lockedfile import common, fileio + + +class TestFileIO(unittest.TestCase): + """Tests for the lockedfile/fileio.py functionality.""" + + def setUp(self): + self.temp_fd, self.temp_path = tempfile.mkstemp() + + def tearDown(self): + os.close(self.temp_fd) + os.unlink(self.temp_path) + + def test_concurrent_access(self): + """ + The purpose of this test is to spawn many readers and + writers using background processes, to ensure that, + regardless of concurrent access attempts, we are always + able to write to the file, and read from the file, + consistent strings. With enough repetitions, if there + are actual race conditions, the file content will + possibly eventually be corrupted and we will notice. + """ + + # TODO(bassosimone): ensure we have some cases of contention + # by collecting statistics about retries. + + valid_contents = ["foo", "bar", "baz", "qux"] + results_queue = mp.Queue() + should_stop = mp.Event() + config = fileio.FileIOConfig(num_retries=30, sleep_interval=0.1) + + def writer_process(): + while not should_stop.is_set(): + for content in valid_contents: + fileio.write(self.temp_path, content, config=config) + + writers = [ + mp.Process(target=writer_process) + for _ in range(4) + ] + + def reader_process(): + while not should_stop.is_set(): + content = fileio.read(self.temp_path, config=config) + if content: + results_queue.put(content) + + readers = [ + mp.Process(target=reader_process) + for _ in range(8) + ] + + # Start all processes + for p in readers + writers: + p.start() + + # Allow the processes to run for a while + time.sleep(1) + + # Interrupt the processes + should_stop.set() + + # Wait for processes to terminate + for p in readers + writers: + p.join() + + # Collect the results from the queue + observed_contents = set() + while not results_queue.empty(): + value = results_queue.get() + observed_contents.add(value) + + # Ensure we never read garbled data + self.assertTrue( + len(observed_contents) > 0, + "No data was read", + ) + self.assertTrue( + observed_contents.issubset(valid_contents), + f"Observed contents: {observed_contents}" + ) + + +if __name__ == "__main__": + unittest.main() -- GitLab From 3bdefe46c85f106928b42e20b7a6933ad2ea9f3f Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 8 Feb 2025 12:51:53 +0100 Subject: [PATCH 36/75] fix(oonireport): repair unit tests --- oonireport/model.py | 2 +- tests/oonireport/test_collector.py | 12 ++++++---- tests/oonireport/test_load.py | 38 ++++++++++++++++++++++++++++++ tests/oonireport/test_model.py | 11 +++++---- 4 files changed, 54 insertions(+), 9 deletions(-) diff --git a/oonireport/model.py b/oonireport/model.py index 8393b4e..99faec9 100644 --- a/oonireport/model.py +++ b/oonireport/model.py @@ -18,7 +18,7 @@ class APIError(Exception): class TestKeys(Protocol): """Models the OONI measurement test keys.""" - def as_dict(self) -> Dict: + def as_dict(self) -> Dict[str, Any]: """Converts the test keys to a JSON-serializable dict.""" ... diff --git a/tests/oonireport/test_collector.py b/tests/oonireport/test_collector.py index c3b6060..eade928 100644 --- a/tests/oonireport/test_collector.py +++ b/tests/oonireport/test_collector.py @@ -40,6 +40,7 @@ class TestCollectorClient(unittest.TestCase): collector_base_url="https://example.org", timeout=30.0 ) self.client = CollectorClient(self.config) + self.dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) @patch("urllib.request.urlopen") def test_create_report_success(self, mock_urlopen): @@ -52,6 +53,7 @@ class TestCollectorClient(unittest.TestCase): software_version="3.0.0", probe_asn="AS12345", probe_cc="IT", + test_start_time=self.dt, ) self.assertEqual(report_id, "test_report_id") @@ -68,6 +70,7 @@ class TestCollectorClient(unittest.TestCase): software_version="3.0.0", probe_asn="AS12345", probe_cc="IT", + test_start_time=self.dt, ) self.assertIn("HTTP error", str(cm.exception)) @@ -86,6 +89,7 @@ class TestCollectorClient(unittest.TestCase): software_version="3.0.0", probe_asn="AS12345", probe_cc="IT", + test_start_time=self.dt, ) self.assertIn("missing report_id", str(cm.exception)) @@ -93,14 +97,14 @@ class TestCollectorClient(unittest.TestCase): @patch("urllib.request.urlopen") def test_update_report_success(self, mock_urlopen): mock_urlopen.return_value = MockResponse( - 200, {"measurement_id": "test_measurement_id"} + 200, {"measurement_uid": "test_measurement_uid"} ) measurement = Measurement( annotations={}, data_format_version="0.2.0", input="https://example.com", - measurement_start_time=datetime.now(timezone.utc), + measurement_start_time=self.dt, probe_asn="AS12345", probe_cc="IT", software_name="ooniprobe", @@ -114,7 +118,7 @@ class TestCollectorClient(unittest.TestCase): measurement_id = self.client.update_report("test_report_id", measurement) - self.assertEqual(measurement_id, "test_measurement_id") + self.assertEqual(measurement_id, "test_measurement_uid") @patch("urllib.request.urlopen") def test_update_report_http_error(self, mock_urlopen): @@ -124,7 +128,7 @@ class TestCollectorClient(unittest.TestCase): annotations={}, data_format_version="0.2.0", input="https://example.com", - measurement_start_time=datetime.now(timezone.utc), + measurement_start_time=self.dt, probe_asn="AS12345", probe_cc="IT", software_name="ooniprobe", diff --git a/tests/oonireport/test_load.py b/tests/oonireport/test_load.py index d166d93..a1c4505 100644 --- a/tests/oonireport/test_load.py +++ b/tests/oonireport/test_load.py @@ -35,5 +35,43 @@ class TestMeasurementLoading(unittest.TestCase): try: measurements = load_measurements(filep.name) self.assertEqual(len(measurements), 2) + for m in measurements: + dm = m.as_dict() + self.assertEqual(dm["data_format_version"], "0.2.0") + self.assertEqual(dm["input"], "https://example.com/") + self.assertEqual(dm["measurement_start_time"], "2023-01-01 12:00:00") + self.assertEqual(dm["probe_asn"], "AS12345") + self.assertEqual(dm["probe_cc"], "IT") + self.assertEqual(dm["probe_ip"], "127.0.0.1") + self.assertEqual(dm["report_id"], "") + self.assertEqual(dm["software_name"], "ooniprobe") + self.assertEqual(dm["software_version"], "3.0.0") + self.assertEqual(dm["test_keys"], {"simple": "test"}) + self.assertEqual(dm["test_name"], "web_connectivity") + self.assertEqual(dm["test_runtime"], 1.0) + self.assertEqual(dm["test_start_time"], "2023-01-01 12:00:00") + self.assertEqual(dm["test_version"], "0.0.1") + self.assertEqual( + set(dm.keys()), + set( + [ + "annotations", + "data_format_version", + "input", + "measurement_start_time", + "probe_asn", + "probe_cc", + "probe_ip", + "report_id", + "software_name", + "software_version", + "test_keys", + "test_name", + "test_runtime", + "test_start_time", + "test_version", + ] + ), + ) finally: os.unlink(filep.name) diff --git a/tests/oonireport/test_model.py b/tests/oonireport/test_model.py index a2ede64..d825207 100644 --- a/tests/oonireport/test_model.py +++ b/tests/oonireport/test_model.py @@ -42,21 +42,24 @@ class TestModel(unittest.TestCase): self.assertEqual(data["measurement_start_time"], "2023-01-01 12:00:00") self.assertEqual(data["probe_asn"], "AS12345") self.assertEqual(data["probe_cc"], "IT") + self.assertEqual(data["software_name"], "ooniprobe") + self.assertEqual(data["software_version"], "3.0.0") self.assertEqual(data["test_keys"], {"simple": "test"}) self.assertEqual(data["test_name"], "web_connectivity") self.assertEqual(data["test_runtime"], 1.0) self.assertEqual(data["test_start_time"], "2023-01-01 12:00:00") self.assertEqual(data["test_version"], "0.0.1") - def test_datetime_to_ooni_format(self): + def test_datetime_to_ooni_format_utc(self): dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) formatted = datetime_to_ooni_format(dt) self.assertEqual(formatted, "2023-01-01 12:00:00") - # Test timezone conversion - dt = datetime(2023, 1, 1, 14, 0, tzinfo=timezone(timedelta(hours=2))) + def test_datetime_to_ooni_format_timezone(self): + # Ensure we're correctly converting when there's a timezone + dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=2))) formatted = datetime_to_ooni_format(dt) - self.assertEqual(formatted, "2023-01-01 12:00:00") + self.assertEqual(formatted, "2023-01-01 10:00:00") if __name__ == "__main__": -- GitLab From 5e0343967c1ab2902f7512ed53dd172bedbc233d Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 8 Feb 2025 13:21:09 +0100 Subject: [PATCH 37/75] fix: use solitech/ instead of leap/ as prefix --- tunnelmetrics/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tunnelmetrics/model.py b/tunnelmetrics/model.py index ca8b8e4..5a34563 100644 --- a/tunnelmetrics/model.py +++ b/tunnelmetrics/model.py @@ -34,7 +34,7 @@ class AggregatorConfig: # rounding sample_size to the nearest round_to round_to: int = 100 - software_name: str = "leap/aggregate-tunnel-metrics" + software_name: str = "solitech/aggregate-tunnel-metrics" software_version: str = "0.1.0" -- GitLab From 054872ebcf218396ace23c4c912d6c6b383ae2d4 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 8 Feb 2025 13:21:39 +0100 Subject: [PATCH 38/75] fix: software_name cannot contain a `/` --- tunnelmetrics/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tunnelmetrics/model.py b/tunnelmetrics/model.py index 5a34563..ff935f4 100644 --- a/tunnelmetrics/model.py +++ b/tunnelmetrics/model.py @@ -34,7 +34,7 @@ class AggregatorConfig: # rounding sample_size to the nearest round_to round_to: int = 100 - software_name: str = "solitech/aggregate-tunnel-metrics" + software_name: str = "solitech-aggregate-tunnel-metrics" software_version: str = "0.1.0" -- GitLab From 72da8f0800ab25c7437a5bfebd4a8a3b0765cb72 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 8 Feb 2025 13:29:24 +0100 Subject: [PATCH 39/75] refactor: move code I wrote into an umbrella package The intent here is to clearly separate that which pertains to aggregate tunnel metrics and the rest of the codebase. Further refactorings may happen in the future, but right now it seems to make sense to clearly indicate boundaries. --- aggregatetunnelmetrics/__init__.py | 1 + .../fieldtestingcsv}/__init__.py | 0 .../fieldtestingcsv}/model.py | 0 .../fieldtestingcsv}/parser.py | 0 .../lockedfile}/__init__.py | 0 .../lockedfile}/common.py | 0 .../lockedfile}/fileio.py | 0 .../lockedfile}/mutex.py | 0 .../oonireport}/__init__.py | 0 .../oonireport}/__main__.py | 0 .../oonireport}/collector.py | 0 .../oonireport}/load.py | 0 .../oonireport}/model.py | 0 .../fieldtestingcsv/__init__.py | 0 .../fieldtestingcsv/test_parser.py | 3 +-- .../lockedfile/__init__.py | 0 .../lockedfile/fileio.py | 14 ++++---------- .../oonireport/__init__.py | 0 .../oonireport/test_collector.py | 8 ++++++-- .../oonireport/test_load.py | 2 +- .../oonireport/test_model.py | 2 +- 21 files changed, 14 insertions(+), 16 deletions(-) create mode 100644 aggregatetunnelmetrics/__init__.py rename {fieldtestingcsv => aggregatetunnelmetrics/fieldtestingcsv}/__init__.py (100%) rename {fieldtestingcsv => aggregatetunnelmetrics/fieldtestingcsv}/model.py (100%) rename {fieldtestingcsv => aggregatetunnelmetrics/fieldtestingcsv}/parser.py (100%) rename {lockedfile => aggregatetunnelmetrics/lockedfile}/__init__.py (100%) rename {lockedfile => aggregatetunnelmetrics/lockedfile}/common.py (100%) rename {lockedfile => aggregatetunnelmetrics/lockedfile}/fileio.py (100%) rename {lockedfile => aggregatetunnelmetrics/lockedfile}/mutex.py (100%) rename {oonireport => aggregatetunnelmetrics/oonireport}/__init__.py (100%) rename {oonireport => aggregatetunnelmetrics/oonireport}/__main__.py (100%) rename {oonireport => aggregatetunnelmetrics/oonireport}/collector.py (100%) rename {oonireport => aggregatetunnelmetrics/oonireport}/load.py (100%) rename {oonireport => aggregatetunnelmetrics/oonireport}/model.py (100%) rename tests/{ => aggregatetunnelmetrics}/fieldtestingcsv/__init__.py (100%) rename tests/{ => aggregatetunnelmetrics}/fieldtestingcsv/test_parser.py (98%) rename tests/{ => aggregatetunnelmetrics}/lockedfile/__init__.py (100%) rename tests/{ => aggregatetunnelmetrics}/lockedfile/fileio.py (89%) rename tests/{ => aggregatetunnelmetrics}/oonireport/__init__.py (100%) rename tests/{ => aggregatetunnelmetrics}/oonireport/test_collector.py (97%) rename tests/{ => aggregatetunnelmetrics}/oonireport/test_load.py (97%) rename tests/{ => aggregatetunnelmetrics}/oonireport/test_model.py (96%) diff --git a/aggregatetunnelmetrics/__init__.py b/aggregatetunnelmetrics/__init__.py new file mode 100644 index 0000000..2d6d74a --- /dev/null +++ b/aggregatetunnelmetrics/__init__.py @@ -0,0 +1 @@ +"""Contains package for the aggregate tunnel metrics functionality.""" diff --git a/fieldtestingcsv/__init__.py b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py similarity index 100% rename from fieldtestingcsv/__init__.py rename to aggregatetunnelmetrics/fieldtestingcsv/__init__.py diff --git a/fieldtestingcsv/model.py b/aggregatetunnelmetrics/fieldtestingcsv/model.py similarity index 100% rename from fieldtestingcsv/model.py rename to aggregatetunnelmetrics/fieldtestingcsv/model.py diff --git a/fieldtestingcsv/parser.py b/aggregatetunnelmetrics/fieldtestingcsv/parser.py similarity index 100% rename from fieldtestingcsv/parser.py rename to aggregatetunnelmetrics/fieldtestingcsv/parser.py diff --git a/lockedfile/__init__.py b/aggregatetunnelmetrics/lockedfile/__init__.py similarity index 100% rename from lockedfile/__init__.py rename to aggregatetunnelmetrics/lockedfile/__init__.py diff --git a/lockedfile/common.py b/aggregatetunnelmetrics/lockedfile/common.py similarity index 100% rename from lockedfile/common.py rename to aggregatetunnelmetrics/lockedfile/common.py diff --git a/lockedfile/fileio.py b/aggregatetunnelmetrics/lockedfile/fileio.py similarity index 100% rename from lockedfile/fileio.py rename to aggregatetunnelmetrics/lockedfile/fileio.py diff --git a/lockedfile/mutex.py b/aggregatetunnelmetrics/lockedfile/mutex.py similarity index 100% rename from lockedfile/mutex.py rename to aggregatetunnelmetrics/lockedfile/mutex.py diff --git a/oonireport/__init__.py b/aggregatetunnelmetrics/oonireport/__init__.py similarity index 100% rename from oonireport/__init__.py rename to aggregatetunnelmetrics/oonireport/__init__.py diff --git a/oonireport/__main__.py b/aggregatetunnelmetrics/oonireport/__main__.py similarity index 100% rename from oonireport/__main__.py rename to aggregatetunnelmetrics/oonireport/__main__.py diff --git a/oonireport/collector.py b/aggregatetunnelmetrics/oonireport/collector.py similarity index 100% rename from oonireport/collector.py rename to aggregatetunnelmetrics/oonireport/collector.py diff --git a/oonireport/load.py b/aggregatetunnelmetrics/oonireport/load.py similarity index 100% rename from oonireport/load.py rename to aggregatetunnelmetrics/oonireport/load.py diff --git a/oonireport/model.py b/aggregatetunnelmetrics/oonireport/model.py similarity index 100% rename from oonireport/model.py rename to aggregatetunnelmetrics/oonireport/model.py diff --git a/tests/fieldtestingcsv/__init__.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/__init__.py similarity index 100% rename from tests/fieldtestingcsv/__init__.py rename to tests/aggregatetunnelmetrics/fieldtestingcsv/__init__.py diff --git a/tests/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py similarity index 98% rename from tests/fieldtestingcsv/test_parser.py rename to tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index 08291e2..e1b8a3b 100644 --- a/tests/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -3,8 +3,7 @@ from datetime import datetime, timezone import tempfile import os -from fieldtestingcsv import parse -from fieldtestingcsv.model import Entry +from aggregatetunnelmetrics.fieldtestingcsv import Entry, parse class TestFieldTestingCSVParser(unittest.TestCase): diff --git a/tests/lockedfile/__init__.py b/tests/aggregatetunnelmetrics/lockedfile/__init__.py similarity index 100% rename from tests/lockedfile/__init__.py rename to tests/aggregatetunnelmetrics/lockedfile/__init__.py diff --git a/tests/lockedfile/fileio.py b/tests/aggregatetunnelmetrics/lockedfile/fileio.py similarity index 89% rename from tests/lockedfile/fileio.py rename to tests/aggregatetunnelmetrics/lockedfile/fileio.py index a5da0d2..7231cb5 100644 --- a/tests/lockedfile/fileio.py +++ b/tests/aggregatetunnelmetrics/lockedfile/fileio.py @@ -10,7 +10,7 @@ import tempfile import time import unittest -from lockedfile import common, fileio +from aggregatetunnelmetrics.lockedfile import common, fileio class TestFileIO(unittest.TestCase): @@ -47,10 +47,7 @@ class TestFileIO(unittest.TestCase): for content in valid_contents: fileio.write(self.temp_path, content, config=config) - writers = [ - mp.Process(target=writer_process) - for _ in range(4) - ] + writers = [mp.Process(target=writer_process) for _ in range(4)] def reader_process(): while not should_stop.is_set(): @@ -58,10 +55,7 @@ class TestFileIO(unittest.TestCase): if content: results_queue.put(content) - readers = [ - mp.Process(target=reader_process) - for _ in range(8) - ] + readers = [mp.Process(target=reader_process) for _ in range(8)] # Start all processes for p in readers + writers: @@ -90,7 +84,7 @@ class TestFileIO(unittest.TestCase): ) self.assertTrue( observed_contents.issubset(valid_contents), - f"Observed contents: {observed_contents}" + f"Observed contents: {observed_contents}", ) diff --git a/tests/oonireport/__init__.py b/tests/aggregatetunnelmetrics/oonireport/__init__.py similarity index 100% rename from tests/oonireport/__init__.py rename to tests/aggregatetunnelmetrics/oonireport/__init__.py diff --git a/tests/oonireport/test_collector.py b/tests/aggregatetunnelmetrics/oonireport/test_collector.py similarity index 97% rename from tests/oonireport/test_collector.py rename to tests/aggregatetunnelmetrics/oonireport/test_collector.py index eade928..aa48abd 100644 --- a/tests/oonireport/test_collector.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_collector.py @@ -6,8 +6,12 @@ from datetime import datetime, timezone import json import unittest -from oonireport.collector import CollectorClient, CollectorConfig -from oonireport.model import APIError, Measurement +from aggregatetunnelmetrics.oonireport import ( + APIError, + CollectorClient, + CollectorConfig, + Measurement, +) class MockResponse: diff --git a/tests/oonireport/test_load.py b/tests/aggregatetunnelmetrics/oonireport/test_load.py similarity index 97% rename from tests/oonireport/test_load.py rename to tests/aggregatetunnelmetrics/oonireport/test_load.py index a1c4505..10f795b 100644 --- a/tests/oonireport/test_load.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_load.py @@ -5,7 +5,7 @@ import os import tempfile import unittest -from oonireport import load_measurements +from aggregatetunnelmetrics.oonireport import load_measurements SAMPLE_MEASUREMENT = { diff --git a/tests/oonireport/test_model.py b/tests/aggregatetunnelmetrics/oonireport/test_model.py similarity index 96% rename from tests/oonireport/test_model.py rename to tests/aggregatetunnelmetrics/oonireport/test_model.py index d825207..258f7a0 100644 --- a/tests/oonireport/test_model.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_model.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone import unittest -from oonireport.model import Measurement, datetime_to_ooni_format +from aggregatetunnelmetrics.oonireport import Measurement, datetime_to_ooni_format class SimpleTestKeys: -- GitLab From 655c13223d611bc37a92d90d72a013075a4fc640 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 8 Feb 2025 18:16:59 +0100 Subject: [PATCH 40/75] feat: aggregate and format for global scope --- .../fieldtestingcsv/model.py | 10 +- .../globalscope/__init__.py | 148 +++++++ .../ooniformatter/__init__.py | 383 ++++++++++++++++++ .../fieldtestingcsv/test_parser.py | 7 +- 4 files changed, 536 insertions(+), 12 deletions(-) create mode 100644 aggregatetunnelmetrics/globalscope/__init__.py create mode 100644 aggregatetunnelmetrics/ooniformatter/__init__.py diff --git a/aggregatetunnelmetrics/fieldtestingcsv/model.py b/aggregatetunnelmetrics/fieldtestingcsv/model.py index 4da841c..79a183c 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/model.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/model.py @@ -15,8 +15,7 @@ class Entry: of the fields within the CSV file. """ - # Fields originally present in the CSV file - # format as of 2024-12-06 + # Fields present in the CSV file format as of 2024-12-06 filename: str date: datetime asn: str @@ -41,11 +40,8 @@ class Entry: err_message: str protocol: str - # Fields added on 2025-02-03 to allow for exporting - # the global aggregate tunnel metrics. - # - # TODO(XXX): update the CSV file spec and generation. - ping_target_address: str + # TODO(bassosimone): do we need to specialize on the ping target address + # or shall we just consider it to be a constant? def is_tunnel_measurement(self) -> bool: """ diff --git a/aggregatetunnelmetrics/globalscope/__init__.py b/aggregatetunnelmetrics/globalscope/__init__.py new file mode 100644 index 0000000..1467502 --- /dev/null +++ b/aggregatetunnelmetrics/globalscope/__init__.py @@ -0,0 +1,148 @@ +""" +Logic for aggregating field-testing results using the global +scope to produce OONI measurements ready to submit. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + +from .. import fieldtestingcsv + + +@dataclass +class AggregatorConfig: + """ + Configuration for the measurement aggregator. + """ + + provider: str + upstream_collector: str + probe_asn: str + probe_cc: str + + # threshold below which we emit sample_size + min_sample_size: int = 1000 + + # rounding sample_size to the nearest round_to + round_to: int = 100 + + software_name: str = "solitech-aggregate-tunnel-metrics" + software_version: str = "0.1.0" + + +def datetime_to_compact_utc(dt: datetime) -> str: + """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" + return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +@dataclass +class AggregateProtocolState: + """Flat representation of the ggregated state at global scope.""" + + # Core identification + provider: str + + # Scope identification + protocol: str + + # Time window + window_start: datetime + window_end: datetime + + # Statistics about the creation phase + creation: dict[str, int] = field(default_factory=dict) + + # Statistics about the tunnel ping phase + tunnel_ping_min: list[float] = field(default_factory=list) + tunnel_ping_avg: list[float] = field(default_factory=list) + tunnel_ping_max: list[float] = field(default_factory=list) + tunnel_ping_loss: list[float] = field(default_factory=list) + + # Statistics about the NDT phase + tunnel_ndt_download_throughput: list[float] = field(default_factory=list) + tunnel_ndt_download_latency: list[float] = field(default_factory=list) + tunnel_ndt_download_rexmit: list[float] = field(default_factory=list) + tunnel_ndt_upload_throughput: list[float] = field(default_factory=list) + tunnel_ndt_upload_latency: list[float] = field(default_factory=list) + tunnel_ndt_upload_rexmit: list[float] = field(default_factory=list) + + def _update_error_counts(self, entry: fieldtestingcsv.Entry) -> None: + """Update error counts based on a new entry""" + error_type = ( + "bootstrap.generic_error" if entry.is_tunnel_error_measurement() else "" + ) + self.creation[error_type] = self.creation.get(error_type, 0) + 1 + + def _update_performance_metrics(self, entry: fieldtestingcsv.Entry) -> None: + """Update performance metrics based on a new entry""" + if not entry.is_tunnel_error_measurement(): # only successful measurements + self._update_ping(entry) + self._update_ndt(entry) + + def _update_ping(self, entry: fieldtestingcsv.Entry) -> None: + """Unconditionally update the ping metrics.""" + self.tunnel_ping_min.append(entry.ping_roundtrip_min) + self.tunnel_ping_avg.append(entry.ping_roundtrip_avg) + self.tunnel_ping_max.append(entry.ping_roundtrip_max) + self.tunnel_ping_loss.append(entry.ping_packets_loss) + + def _update_ndt(self, entry: fieldtestingcsv.Entry) -> None: + """Unconditionally update the NDT metrics.""" + self.tunnel_ndt_upload_throughput.append(entry.throughput_download) + self.tunnel_ndt_download_latency.append(entry.latency_download) + self.tunnel_ndt_download_rexmit.append(entry.retransmission_download) + self.tunnel_ndt_upload_throughput.append(entry.throughput_upload) + self.tunnel_ndt_upload_latency.append(entry.latency_upload) + self.tunnel_ndt_upload_rexmit.append(entry.retransmission_upload) + + def update(self, entry: fieldtestingcsv.Entry) -> None: + """ + Update aggregator state with a new measurement. + """ + self._update_error_counts(entry) + self._update_performance_metrics(entry) + + +@dataclass +class AggregateState: + """Aggregates measurements by protocol at global scope.""" + + config: AggregatorConfig + window_start: datetime + window_end: datetime + protocols: dict[str, AggregateProtocolState] = field(default_factory=dict) + + def _is_in_window(self, entry: fieldtestingcsv.Entry) -> bool: + """Check if entry falls within our time window""" + return self.window_start <= entry.date < self.window_end + + def _is_tunnel_entry(self, entry: fieldtestingcsv.Entry) -> bool: + """Check if entry is a tunnel measurement""" + return entry.is_tunnel_measurement() + + def update(self, entry: fieldtestingcsv.Entry) -> None: + """Update aggregator state with a new measurement.""" + # Ensure we're in window and we're looking at a tunnel entry + if not self._is_in_window(entry): + return + if not self._is_tunnel_entry(entry): + return + + # Get or create state for this protocol + if entry.protocol not in self.protocols: + self.protocols[entry.protocol] = AggregateProtocolState( + provider=self.config.provider, + protocol=entry.protocol, + window_start=self.window_start, + window_end=self.window_end, + ) + + # Update the protocol-specific state + self.protocols[entry.protocol].update(entry) diff --git a/aggregatetunnelmetrics/ooniformatter/__init__.py b/aggregatetunnelmetrics/ooniformatter/__init__.py new file mode 100644 index 0000000..10c0942 --- /dev/null +++ b/aggregatetunnelmetrics/ooniformatter/__init__.py @@ -0,0 +1,383 @@ +""" +Formats aggregate tunnel metrics as a OONI measurement. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime +from dataclasses import dataclass +from statistics import quantiles +from typing import Any, Optional +from urllib.parse import urlunparse, urlencode + +import logging + +from ..globalscope import ( + AggregateProtocolState, + AggregatorConfig, + AggregateState, + datetime_to_compact_utc, +) +from ..oonireport import Measurement + + +@dataclass +class AggregationTimeWindow: + """Time window for aggregating measurements""" + + from_time: datetime + to_time: datetime + + def as_dict(self) -> dict: + """Convert to JSON-serializable dict""" + return { + "from": datetime_to_compact_utc(self.from_time), + "to": datetime_to_compact_utc(self.to_time), + } + + +@dataclass +class TestKeys: + """ + Models the test_keys portion of an OONI measurement as defined + in the aggregate tunnel metrics specification. + """ + + # Mandatory fields + provider: str + scope: str + protocol: str + time_window: AggregationTimeWindow + + # Optional fields depending on scope + endpoint_hostname: Optional[str] + endpoint_address: Optional[str] + endpoint_port: Optional[int] + asn: Optional[str] # Format: ^AS[0-9]+$ + cc: Optional[str] # Format: ^[A-Z]{2}$ + + # The bodies should always be present + bodies: list[dict[str, Any]] + + def as_dict(self) -> dict: + """Convert to JSON-serializable dict""" + # Start with required fields + d = { + "provider": self.provider, + "scope": self.scope, + "protocol": self.protocol, + "time_window": self.time_window.as_dict(), + "bodies": self.bodies, + } + + # Add optional fields if they exist + for field in ( + "endpoint_hostname", + "endpoint_address", + "endpoint_port", + "asn", + "cc", + ): + value = getattr(self, field) + if value is not None: + d[field] = value + + return d + + +class SerializationConfigError(Exception): + """Raised when serialization configuration does not allow us to proceed.""" + + +class Serializer: + """Converts aggregate endpoint state into OONI measurements""" + + def __init__(self, config: AggregatorConfig): + self.config = config + + @staticmethod + def _compute_percentiles(values: list[float]) -> dict[str, float]: + """Compute the required percentiles for OONI format""" + + if not values: + return {} + + q = quantiles(values, n=100, method="exclusive") + return { + "25p": round(q[24], 1), + "50p": round(q[49], 1), + "75p": round(q[74], 1), + "99p": round(q[98], 1), + } + + def _create_input_url(self, state: AggregateProtocolState) -> str: + """Create the measurement input URL""" + # The query is empty when using the global state + query = {} + + # Build URL using urlunparse for safety + return urlunparse( + ( + state.protocol, # scheme (e.g., "openvpn+obfs4") + state.provider, # netloc (e.g., "riseup.net") + "/", # path + "", # params + urlencode(query), # query (e.g., "address=1.2.3.4&...") + "", # fragment + ) + ) + + def _round_sample_size(self, sample_size: int) -> Optional[int]: + """Round the sample size according to the aggregate tunnel metrics spec.""" + if sample_size < self.config.min_sample_size: + return None + return round(sample_size / self.config.round_to) * self.config.round_to + + @staticmethod + def _maybe_with_sample_size( + obj: dict[str, Any], ss: Optional[int] + ) -> dict[str, Any]: + if ss is not None: + obj["sample_size"] = ss + return obj + + def _create_error_bodies( + self, state: AggregateProtocolState + ) -> list[dict[str, Any]]: + """Create error bodies if there are any errors""" + bodies = [] + total = sum(state.creation.values()) + if total > 0: + for error_type, count in state.creation.items(): + if not error_type: # Skip success counts + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": "creation", + "type": "network-error", + "failure_ratio": round(count / total, 2), + "error": error_type, + }, + self._round_sample_size(count), + ) + ) + return bodies + + def _validate_ping_measurements( + self, state: AggregateProtocolState, metric_type: str, measurements: list[float] + ) -> None: + """Validate ping measurements""" + if metric_type in ["min", "avg", "max"]: + for min_v, avg_v, max_v in zip( + state.tunnel_ping_min, state.tunnel_ping_avg, state.tunnel_ping_max + ): + if not (0 <= min_v <= avg_v <= max_v): + raise SerializationConfigError("invalid ping latency ordering") + if max_v > 60000: # 60 seconds + raise SerializationConfigError("unreasonable ping latency") + elif metric_type == "loss": + for loss in state.tunnel_ping_loss: + if not (0 <= loss <= 100): + raise SerializationConfigError("ping loss out of range") + + def _create_ping_bodies( + self, state: AggregateProtocolState + ) -> list[dict[str, Any]]: + """Create bodies for ping measurements""" + bodies = [] + items = ( + ("min", "latency_ms", state.tunnel_ping_min), + ("avg", "latency_ms", state.tunnel_ping_avg), + ("max", "latency_ms", state.tunnel_ping_max), + ("loss", "loss_percent", state.tunnel_ping_loss), + ) + for metric_type, key, measurements in items: + if measurements: # Only if we have measurements + try: + self._validate_ping_measurements(state, metric_type, measurements) + except SerializationConfigError as e: + logging.warning(str(e)) + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ping", + "type": f"ping_{metric_type}", + "target_address": "", + key: self._compute_percentiles(measurements), + }, + self._round_sample_size(len(measurements)), + ) + ) + return bodies + + def _validate_ndt_measurements( + self, + throughput: list[float], + latency: list[float], + rexmit: list[float], + phase: str, + ) -> None: + """Validate NDT measurements""" + if len({len(throughput), len(latency), len(rexmit)}) > 1: + raise SerializationConfigError( + f"inconsistent NDT {phase} measurement counts" + ) + for t, l, r in zip(throughput, latency, rexmit): + if t < 0 or t > 10000: # 10 Gbps + raise SerializationConfigError(f"unreasonable NDT {phase} throughput") + if l < 0 or l > 60000: # 60 seconds + raise SerializationConfigError(f"unreasonable NDT {phase} latency") + if not (0 <= r <= 100): + raise SerializationConfigError( + f"NDT {phase} retransmission out of range" + ) + + def _create_ndt_bodies(self, state: AggregateProtocolState) -> list[dict[str, Any]]: + """Create bodies for NDT measurements""" + bodies = [] + items = ( + ( + "download", + state.tunnel_ndt_download_throughput, + state.tunnel_ndt_download_latency, + state.tunnel_ndt_download_rexmit, + ), + ( + "upload", + state.tunnel_ndt_upload_throughput, + state.tunnel_ndt_upload_latency, + state.tunnel_ndt_upload_rexmit, + ), + ) + for phase, throughput, latency, rexmit in items: + try: + self._validate_ndt_measurements(throughput, latency, rexmit, phase) + except SerializationConfigError as e: + logging.warning(str(e)) + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": f"tunnel_ndt_{phase}", + "type": f"ndt_{phase}", + "target_hostname": "", + "target_address": "", + "target_port": 0, + "latency_ms": self._compute_percentiles(latency), + "speed_mbits": self._compute_percentiles(throughput), + "retransmission_percent": self._compute_percentiles(rexmit), + }, + self._round_sample_size(len(throughput)), + ) + ) + return bodies + + def _create_global_bodies( + self, state: AggregateProtocolState + ) -> list[dict[str, Any]]: + """Create the bodies section of test_keys""" + bodies = [] + bodies.extend(self._create_error_bodies(state)) + bodies.extend(self._create_ping_bodies(state)) + bodies.extend(self._create_ndt_bodies(state)) + return bodies + + def _is_valid_state(self, state: AggregateProtocolState) -> bool: + """ + Validates the state before serialization. Returns False if state + should be skipped, True if it's valid to serialize. + + Logs warning messages explaining validation failures. + """ + + # Basic field validations + if not state.provider or not isinstance(state.provider, str): + logging.warning("invalid provider field") + return False + if not state.protocol or not isinstance(state.protocol, str): + logging.warning("invalid protocol field") + return False + if not state.window_start or not state.window_end: + logging.warning("invalid time window") + return False + if state.window_end <= state.window_start: + logging.warning("end time before start time") + return False + + # Creation phase validations + if not state.creation: + logging.warning("no creation phase data") + return False + if any(count < 0 for count in state.creation.values()): + logging.warning("negative creation counts") + return False + if sum(state.creation.values()) == 0: + logging.warning("no measurements") + return False + + # Logical consistency validations + success_count = state.creation.get("", 0) + has_measurements = bool( + state.tunnel_ping_min or state.tunnel_ndt_download_throughput + ) + if success_count > 0 and not has_measurements: + logging.warning("successful creations but no measurements") + return False + if success_count == 0 and has_measurements: + logging.warning("measurements without successful creations") + return False + + return True + + def serialize_global(self, state: AggregateState) -> list[Measurement]: + """ + Convert global state to OONI measurement format. + + Raises: + SerializationError: if the scope is not model.Scope.ENDPOINT. + """ + measurement_time = datetime.utcnow() + measurements = [] + + for proto_name, proto_state in state.protocols.items(): + if not self._is_valid_state(proto_state): + logging.warning(f"skipping invalid state for protocol {proto_name}") + continue + + test_keys = TestKeys( + provider=state.config.provider, + scope="global", + protocol=proto_name, + time_window=AggregationTimeWindow( + from_time=state.window_start, to_time=state.window_end + ), + endpoint_hostname=None, + endpoint_address=None, + endpoint_port=None, + asn=None, + cc=None, + bodies=self._create_global_bodies(proto_state), + ) + + mx = Measurement( + annotations={"upstream_collector": self.config.upstream_collector}, + data_format_version="0.2.0", + input=self._create_input_url(proto_state), + measurement_start_time=measurement_time, + probe_asn=self.config.probe_asn, + probe_cc=self.config.probe_cc, + software_name=self.config.software_name, + software_version=self.config.software_version, + test_keys=test_keys, + test_name="aggregate_tunnel_metrics", + test_runtime=0.0, + test_start_time=measurement_time, + test_version="0.1.0", + ) + measurements.append(mx) + + return measurements diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index e1b8a3b..c1c446c 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -25,8 +25,8 @@ class TestFieldTestingCSVParser(unittest.TestCase): def test_parse_valid_entry(self): """Test parsing a valid CSV row""" - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,ping_target_address -test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,8.8.8.8""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" self.write_csv(csv_content) entries = parse(self.csv_path) @@ -37,7 +37,6 @@ test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion, self.assertEqual(entry.filename, "test.csv") self.assertEqual(entry.asn, "AS12345") self.assertEqual(entry.protocol, "obfs4") # Note: mapped from PT - self.assertEqual(entry.ping_target_address, "8.8.8.8") self.assertEqual(entry.throughput_download, 100.0) self.assertEqual(entry.mobile, False) @@ -99,7 +98,6 @@ test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion, ping_roundtrip_max=15.0, err_message="", protocol="obfs4", - ping_target_address="8.8.8.8", ) self.assertTrue(entry.is_tunnel_measurement()) @@ -136,7 +134,6 @@ test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion, ping_roundtrip_max=15.0, err_message="", protocol="obfs4", - ping_target_address="8.8.8.8", ) self.assertTrue(entry.is_tunnel_error_measurement()) -- GitLab From 2cecb018f1f76961e5a0ea729d66aebd9aac5cc0 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 16:08:21 +0100 Subject: [PATCH 41/75] fix: make sure we can run tests --- aggregatetunnelmetrics/fieldtestingcsv/parser.py | 1 - tests/aggregatetunnelmetrics/__init__.py | 0 .../fieldtestingcsv/test_parser.py | 11 +++++++---- 3 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 tests/aggregatetunnelmetrics/__init__.py diff --git a/aggregatetunnelmetrics/fieldtestingcsv/parser.py b/aggregatetunnelmetrics/fieldtestingcsv/parser.py index a2d9fbb..ac08994 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/parser.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/parser.py @@ -81,7 +81,6 @@ def parse(filename: str) -> List[Entry]: ping_roundtrip_max=float(row["ping_roundtrip_max"]), err_message=str(row["err_message"]).strip(), protocol=str(row["PT"]), # rename from "PT" to "protocol" - ping_target_address=str(row["ping_target_address"]), ) entries.append(measurement) diff --git a/tests/aggregatetunnelmetrics/__init__.py b/tests/aggregatetunnelmetrics/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index c1c446c..db1b494 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -1,10 +1,13 @@ -import unittest from datetime import datetime, timezone -import tempfile + +import logging import os +import tempfile +import unittest from aggregatetunnelmetrics.fieldtestingcsv import Entry, parse +logging.basicConfig(level=logging.ERROR) class TestFieldTestingCSVParser(unittest.TestCase): def setUp(self): @@ -65,8 +68,8 @@ test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion, def test_missing_required_field(self): """Test handling missing required fields""" # Create CSV with missing ping_target_address - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT -test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,""" self.write_csv(csv_content) entries = parse(self.csv_path) -- GitLab From fa5da61b3679adeb9a6fa96e79a31425ef95daa1 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 17:03:25 +0100 Subject: [PATCH 42/75] feat(fieldtestingcsv): ensure we have good coverage --- aggregatetunnelmetrics/__init__.py | 4 +- .../fieldtestingcsv/__init__.py | 10 +- .../fieldtestingcsv/model.py | 8 +- .../fieldtestingcsv/parser.py | 128 +++++---- .../fieldtestingcsv/test_parser.py | 265 ++++++++++++------ 5 files changed, 269 insertions(+), 146 deletions(-) diff --git a/aggregatetunnelmetrics/__init__.py b/aggregatetunnelmetrics/__init__.py index 2d6d74a..3c808ef 100644 --- a/aggregatetunnelmetrics/__init__.py +++ b/aggregatetunnelmetrics/__init__.py @@ -1 +1,3 @@ -"""Contains package for the aggregate tunnel metrics functionality.""" +"""Container package for the aggregate tunnel metrics functionality.""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/aggregatetunnelmetrics/fieldtestingcsv/__init__.py b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py index c378114..efffa6b 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/__init__.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py @@ -2,18 +2,14 @@ Field-Testing CSV ================= -This package contains code for managing the field-testing CSV -based data format (e.g., loading and parsing). +This package contains code for parsing the field-testing CSV data format. See https://0xacab.org/leap/solitech-compose-client/-/blob/main/images/obfsvpn-openvpn-client/start.sh. """ # SPDX-License-Identifier: GPL-3.0-or-later -# TODO(bassosimone): documented the expected CSV field at -# the package level rather than just linking to code. - from .model import Entry -from .parser import parse +from .parser import parse_file -__all__ = ["Entry", "parse"] +__all__ = ["Entry", "parse_file"] diff --git a/aggregatetunnelmetrics/fieldtestingcsv/model.py b/aggregatetunnelmetrics/fieldtestingcsv/model.py index 79a183c..85d1f25 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/model.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/model.py @@ -1,4 +1,8 @@ -"""Field-Testing CSV model.""" +""" +Internal module implementing the field-Testing CSV model. + +We expect you to import `fieldtestingcsv` directly, not this module. +""" # SPDX-License-Identifier: GPL-3.0-or-later @@ -6,7 +10,7 @@ from dataclasses import dataclass from datetime import datetime -@dataclass +@dataclass(frozen=True) class Entry: """ Models a single field-testing entry read from the CSV datastore. diff --git a/aggregatetunnelmetrics/fieldtestingcsv/parser.py b/aggregatetunnelmetrics/fieldtestingcsv/parser.py index ac08994..3e70dbd 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/parser.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/parser.py @@ -1,9 +1,13 @@ -"""Field-Testing CSV parser.""" +""" +Internal module implementing the field-Testing CSV parser. + +We expect you to import `fieldtestingcsv` directly, not this module. +""" # SPDX-License-Identifier: GPL-3.0-or-later from datetime import datetime, timezone -from typing import Iterator, List +from typing import Dict, Iterator, List import csv import logging @@ -11,12 +15,16 @@ import logging from .model import Entry -def _parse_datetime(date_str: str) -> datetime: +def parse_datetime(date_str: str) -> datetime: """ Parse ctime formatted date from CSV into datetime object. Example format: "Fri Dec 6 15:27:16 UTC 2024" """ + # For now, since we expect UTC, let's be strict + if "UTC" not in date_str: + raise ValueError(f"expected UTC timezone in date string, got: {date_str}") + # strptime directives: # %a - Weekday name (Mon) # %b - Month name (Nov) @@ -26,16 +34,10 @@ def _parse_datetime(date_str: str) -> datetime: # %Y - Year (2024) dt = datetime.strptime(date_str, "%a %b %d %H:%M:%S %Z %Y") - # For now, since we expect UTC, let's be strict - # - # TODO(bassosimone): do we need to care about non-UTC? - if "UTC" not in date_str: - raise ValueError(f"expected UTC timezone in date string, got: {date_str}") - return dt.replace(tzinfo=timezone.utc) -def _parse_bool(value: str) -> bool: +def parse_bool(value: str) -> bool: """ Parse boolean string from CSV into bool. @@ -48,52 +50,72 @@ def _parse_bool(value: str) -> bool: return value.lower() == "true" -def parse(filename: str) -> List[Entry]: - """Parses and returns entries from CSV file.""" - entries = [] +def parse_single_row(row: Dict[str, str]) -> Entry: + """Convert a CSV row into a measurement entry. - with open(filename, "r") as f: - reader = csv.DictReader(f) - for row in reader: - try: + Args: + row: Dictionary containing CSV row data - measurement = Entry( - filename=str(row["filename"]), - date=_parse_datetime(row["date"]), - asn=str(row["asn"]), - isp=str(row["isp"]), - est_city=str(row["est_city"]), - user=str(row["user"]), - region=str(row["region"]), - server_fqdn=str(row["server_fqdn"]), - server_ip=str(row["server_ip"]), - mobile=_parse_bool(row["mobile"]), - tunnel=str(row["tunnel"]), - throughput_download=float(row["throughput_download"]), - throughput_upload=float(row["throughput_upload"]), - latency_download=float(row["latency_download"]), - latency_upload=float(row["latency_upload"]), - retransmission_download=float(row["retransmission_download"]), - retransmission_upload=float(row["retransmission_upload"]), - ping_packets_loss=float(row["ping_packets_loss"]), - ping_roundtrip_min=float(row["ping_roundtrip_min"]), - ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), - ping_roundtrip_max=float(row["ping_roundtrip_max"]), - err_message=str(row["err_message"]).strip(), - protocol=str(row["PT"]), # rename from "PT" to "protocol" - ) - entries.append(measurement) + Returns: + Entry object constructed from row data - except (ValueError, KeyError) as exc: - logging.warning(f"cannot import row: {exc}") - continue + Raises: + ValueError: If numeric fields cannot be parsed + KeyError: If required fields are missing + """ + return Entry( + filename=str(row["filename"]), + date=parse_datetime(row["date"]), + asn=str(row["asn"]), + isp=str(row["isp"]), + est_city=str(row["est_city"]), + user=str(row["user"]), + region=str(row["region"]), + server_fqdn=str(row["server_fqdn"]), + server_ip=str(row["server_ip"]), + mobile=parse_bool(row["mobile"]), + tunnel=str(row["tunnel"]), + throughput_download=float(row["throughput_download"]), + throughput_upload=float(row["throughput_upload"]), + latency_download=float(row["latency_download"]), + latency_upload=float(row["latency_upload"]), + retransmission_download=float(row["retransmission_download"]), + retransmission_upload=float(row["retransmission_upload"]), + ping_packets_loss=float(row["ping_packets_loss"]), + ping_roundtrip_min=float(row["ping_roundtrip_min"]), + ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), + ping_roundtrip_max=float(row["ping_roundtrip_max"]), + err_message=str(row["err_message"]).strip(), + protocol=str(row["PT"]), # rename from "PT" to "protocol" + ) + + +def parse_file(filename: str) -> List[Entry]: + """ + Parses and returns entries from CSV file. + + Note: This function loads all entries into memory. For large files, + consider using stream() instead if you don't need all entries at once. + """ + return list(stream_file(filename)) - return entries +def stream_file(filename: str) -> Iterator[Entry]: + """ + Stream entries from CSV file one at a time. -def stream(filename: str) -> Iterator[Entry]: - """Stream entries from CSV file one at a time.""" - # TODO(bassosimone): implement streaming, which is going to be more - # efficient than loading the whole file on memory, and then reimplement - # the `parse` function to use this function - raise NotImplementedError("streaming not yet implemented") + This function reads and parses the file line by line, yielding + entries as they are parsed. Invalid entries are logged and skipped. + Memory usage is O(1) as entries are not accumulated. + + Yields: + Entry objects one at a time + """ + with open(filename, "r") as filep: + crd = csv.DictReader(filep) + for row in crd: + try: + yield parse_single_row(row) + except (ValueError, KeyError) as exc: + logging.warning(f"cannot import row: {exc}") + continue diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index db1b494..16574ce 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -1,15 +1,158 @@ -from datetime import datetime, timezone +"""Tests for the field-testing CSV parser functionality.""" -import logging -import os -import tempfile -import unittest +# SPDX-License-Identifier: GPL-3.0-or-later -from aggregatetunnelmetrics.fieldtestingcsv import Entry, parse +import unittest +from datetime import datetime, timezone +import tempfile +import os +import logging -logging.basicConfig(level=logging.ERROR) +logging.basicConfig(level=logging.ERROR) # do not spew when running tests + +from aggregatetunnelmetrics.fieldtestingcsv.model import Entry +from aggregatetunnelmetrics.fieldtestingcsv.parser import ( + parse_datetime, + parse_bool, + parse_file, + parse_single_row, + stream_file, +) + + +class TestParserFunctions(unittest.TestCase): + """Test individual parser functions.""" + + def test_parse_datetime(self): + """Test parse_datetime with valid and invalid inputs""" + # Valid UTC date + dt = parse_datetime("Mon Jan 01 12:00:00 UTC 2024") + self.assertEqual(dt, datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)) + + # Non-UTC timezone + with self.assertRaises(ValueError): + parse_datetime("Mon Jan 01 12:00:00 EST 2024") + + # Invalid format + with self.assertRaises(ValueError): + parse_datetime("2024-01-01 12:00:00") + + def test_parse_datetime_non_utc(self): + """Test parse_datetime explicitly with non-UTC timezone""" + with self.assertRaises(ValueError): + parse_datetime("Mon Jan 01 12:00:00 EST 2024") + + def test_parse_bool(self): + """Test parse_bool with various inputs""" + self.assertTrue(parse_bool("true")) + self.assertTrue(parse_bool("TRUE")) + self.assertFalse(parse_bool("false")) + self.assertFalse(parse_bool("FALSE")) + self.assertFalse(parse_bool("invalid")) + + +class TestRowParsing(unittest.TestCase): + """Test parsing individual rows.""" + + def test_parse_valid_row(self): + """Test parsing a valid row dictionary""" + row = { + "filename": "test.csv", + "date": "Mon Jan 01 12:00:00 UTC 2024", + "asn": "AS12345", + "isp": "TestISP", + "est_city": "TestCity", + "user": "user1", + "region": "TestRegion", + "server_fqdn": "ndt.example.com", + "server_ip": "1.2.3.4", + "mobile": "false", + "tunnel": "tunnel", + "throughput_download": "100.0", + "throughput_upload": "50.0", + "latency_download": "20.0", + "latency_upload": "25.0", + "retransmission_download": "0.01", + "retransmission_upload": "0.02", + "ping_packets_loss": "0.0", + "ping_roundtrip_min": "10.0", + "ping_roundtrip_avg": "12.0", + "ping_roundtrip_max": "15.0", + "err_message": "", + "PT": "obfs4", + } + + entry = parse_single_row(row) + self.assertEqual(entry.filename, "test.csv") + self.assertEqual(entry.protocol, "obfs4") + self.assertEqual(entry.throughput_download, 100.0) + self.assertEqual(entry.mobile, False) + self.assertEqual(entry.date, datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)) + self.assertEqual(entry.asn, "AS12345") + self.assertEqual(entry.isp, "TestISP") + self.assertEqual(entry.est_city, "TestCity") + self.assertEqual(entry.user, "user1") + self.assertEqual(entry.region, "TestRegion") + self.assertEqual(entry.server_fqdn, "ndt.example.com") + self.assertEqual(entry.server_ip, "1.2.3.4") + self.assertEqual(entry.tunnel, "tunnel") + self.assertEqual(entry.throughput_upload, 50.0) + self.assertEqual(entry.latency_download, 20.0) + self.assertEqual(entry.latency_upload, 25.0) + self.assertEqual(entry.retransmission_download, 0.01) + self.assertEqual(entry.retransmission_upload, 0.02) + self.assertEqual(entry.ping_packets_loss, 0.0) + self.assertEqual(entry.ping_roundtrip_min, 10.0) + self.assertEqual(entry.ping_roundtrip_avg, 12.0) + self.assertEqual(entry.ping_roundtrip_max, 15.0) + self.assertEqual(entry.err_message, "") + + def test_parse_row_missing_field(self): + """Test parsing row with missing required field""" + row = { + # Missing most required fields + "filename": "test.csv", + "date": "Mon Jan 01 12:00:00 UTC 2024", + } + + with self.assertRaises(KeyError): + parse_single_row(row) + + def test_parse_row_invalid_numeric(self): + """Test parsing row with invalid numeric values""" + valid_row = { + "filename": "test.csv", + "date": "Mon Jan 01 12:00:00 UTC 2024", + "asn": "AS12345", + "isp": "TestISP", + "est_city": "TestCity", + "user": "user1", + "region": "TestRegion", + "server_fqdn": "ndt.example.com", + "server_ip": "1.2.3.4", + "mobile": "false", + "tunnel": "tunnel", + "throughput_download": "invalid", # Invalid numeric value + "throughput_upload": "50.0", + "latency_download": "20.0", + "latency_upload": "25.0", + "retransmission_download": "0.01", + "retransmission_upload": "0.02", + "ping_packets_loss": "0.0", + "ping_roundtrip_min": "10.0", + "ping_roundtrip_avg": "12.0", + "ping_roundtrip_max": "15.0", + "err_message": "", + "PT": "obfs4", + } + + with self.assertRaises(ValueError): + parse_single_row(valid_row) + + +class TestFileOperations(unittest.TestCase): + """Test file-level operations (parse and stream).""" -class TestFieldTestingCSVParser(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.mkdtemp() self.csv_path = os.path.join(self.temp_dir, "test.csv") @@ -26,58 +169,44 @@ class TestFieldTestingCSVParser(unittest.TestCase): with open(self.csv_path, "w") as f: f.write(content) - def test_parse_valid_entry(self): - """Test parsing a valid CSV row""" + def test_stream_valid_file(self): + """Test streaming from a valid CSV file""" csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" self.write_csv(csv_content) - entries = parse(self.csv_path) - + entries = list(stream_file(self.csv_path)) self.assertEqual(len(entries), 1) - entry = entries[0] + self.assertEqual(entries[0].protocol, "obfs4") - self.assertEqual(entry.filename, "test.csv") - self.assertEqual(entry.asn, "AS12345") - self.assertEqual(entry.protocol, "obfs4") # Note: mapped from PT - self.assertEqual(entry.throughput_download, 100.0) - self.assertEqual(entry.mobile, False) - - # Check datetime parsing - expected_dt = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc) - self.assertEqual(entry.date, expected_dt) - - def test_parse_invalid_date(self): - """Test handling invalid date formats""" - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,ping_target_address -test.csv,invalid date,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,8.8.8.8""" + def test_stream_file_with_invalid_rows(self): + """Test streaming from a file with invalid rows""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +test.csv,Mon Jan 01 12:00:00 EST 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4 +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,not_a_number,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" self.write_csv(csv_content) - entries = parse(self.csv_path) + entries = list(stream_file(self.csv_path)) + # Both rows should be skipped due to errors self.assertEqual(len(entries), 0) - def test_parse_invalid_numeric(self): - """Test handling invalid numeric values""" - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT,ping_target_address -test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,invalid,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4,8.8.8.8""" + def test_parse_equivalent_to_stream(self): + """Test that parse_file() returns same results as stream_file()""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" self.write_csv(csv_content) - entries = parse(self.csv_path) - self.assertEqual(len(entries), 0) + streamed = list(stream_file(self.csv_path)) + parsed = parse_file(self.csv_path) + self.assertEqual(streamed, parsed) - def test_missing_required_field(self): - """Test handling missing required fields""" - # Create CSV with missing ping_target_address - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message -test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,""" - self.write_csv(csv_content) - entries = parse(self.csv_path) - self.assertEqual(len(entries), 0) +class TestEntryMethods(unittest.TestCase): + """Test Entry class methods.""" - def test_is_tunnel_measurement(self): - """Test is_tunnel_measurement() method""" - entry = Entry( + @staticmethod + def make_entry(tunnel: str) -> Entry: + return Entry( filename="test.csv", date=datetime.now(timezone.utc), asn="AS12345", @@ -88,7 +217,7 @@ test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion, server_fqdn="ndt.example.com", server_ip="1.2.3.4", mobile=False, - tunnel="tunnel", + tunnel=tunnel, throughput_download=100.0, throughput_upload=50.0, latency_download=20.0, @@ -103,46 +232,16 @@ test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion, protocol="obfs4", ) - self.assertTrue(entry.is_tunnel_measurement()) - - entry.tunnel = "baseline" - self.assertFalse(entry.is_tunnel_measurement()) - - entry.tunnel = "ERROR/tunnel" - self.assertTrue(entry.is_tunnel_measurement()) + def test_is_tunnel_measurement(self): + """Test is_tunnel_measurement() method""" + self.assertTrue(self.make_entry("tunnel").is_tunnel_measurement()) + self.assertFalse(self.make_entry("baseline").is_tunnel_measurement()) + self.assertTrue(self.make_entry("ERROR/tunnel").is_tunnel_measurement()) def test_is_tunnel_error_measurement(self): """Test is_tunnel_error_measurement() method""" - entry = Entry( - filename="test.csv", - date=datetime.now(timezone.utc), - asn="AS12345", - isp="TestISP", - est_city="TestCity", - user="user1", - region="TestRegion", - server_fqdn="ndt.example.com", - server_ip="1.2.3.4", - mobile=False, - tunnel="ERROR/tunnel", - throughput_download=100.0, - throughput_upload=50.0, - latency_download=20.0, - latency_upload=25.0, - retransmission_download=0.01, - retransmission_upload=0.02, - ping_packets_loss=0.0, - ping_roundtrip_min=10.0, - ping_roundtrip_avg=12.0, - ping_roundtrip_max=15.0, - err_message="", - protocol="obfs4", - ) - - self.assertTrue(entry.is_tunnel_error_measurement()) - - entry.tunnel = "tunnel" - self.assertFalse(entry.is_tunnel_error_measurement()) + self.assertTrue(self.make_entry("ERROR/tunnel").is_tunnel_error_measurement()) + self.assertFalse(self.make_entry("tunnel").is_tunnel_error_measurement()) if __name__ == "__main__": -- GitLab From 13dac93247d8135847001b0a62c8d6f13b3340c1 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 17:52:32 +0100 Subject: [PATCH 43/75] feat(lockefile): significanly improve tests --- .gitignore | 1 + aggregatetunnelmetrics/lockedfile/__init__.py | 2 +- aggregatetunnelmetrics/lockedfile/fileio.py | 8 +- aggregatetunnelmetrics/lockedfile/mutex.py | 8 +- .../lockedfile/fileio.py | 92 --------- .../lockedfile/test_fileio.py | 164 +++++++++++++++ .../lockedfile/test_mutex.py | 187 ++++++++++++++++++ 7 files changed, 366 insertions(+), 96 deletions(-) delete mode 100644 tests/aggregatetunnelmetrics/lockedfile/fileio.py create mode 100644 tests/aggregatetunnelmetrics/lockedfile/test_fileio.py create mode 100644 tests/aggregatetunnelmetrics/lockedfile/test_mutex.py diff --git a/.gitignore b/.gitignore index be38085..9633c92 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.coverage __pycache__ logs/* results/* diff --git a/aggregatetunnelmetrics/lockedfile/__init__.py b/aggregatetunnelmetrics/lockedfile/__init__.py index 9aaeadd..3aed17b 100644 --- a/aggregatetunnelmetrics/lockedfile/__init__.py +++ b/aggregatetunnelmetrics/lockedfile/__init__.py @@ -6,7 +6,7 @@ holding locks on them, ensuring safe concurrent access. Patterned after https://github.com/rogpeppe/go-internal `lockedfile`. -Note: this package only works on Unix systems. +Note: this package only works on Unix-like systems. """ # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/aggregatetunnelmetrics/lockedfile/fileio.py b/aggregatetunnelmetrics/lockedfile/fileio.py index 3c0b38b..1e0cdc6 100644 --- a/aggregatetunnelmetrics/lockedfile/fileio.py +++ b/aggregatetunnelmetrics/lockedfile/fileio.py @@ -1,4 +1,8 @@ -"""Safe file I/O operations with locking.""" +""" +Internal implemenation of safe file I/O operations with locking. + +You should typically import `lockefile` directly instead. +""" # SPDX-License-Identifier: GPL-3.0-or-later @@ -13,7 +17,7 @@ import time from .common import FileLockError -@dataclass +@dataclass(frozen=True) class FileIOConfig: """ Configures attempting to acquire a file lock. diff --git a/aggregatetunnelmetrics/lockedfile/mutex.py b/aggregatetunnelmetrics/lockedfile/mutex.py index 3f12f08..a2df5ac 100644 --- a/aggregatetunnelmetrics/lockedfile/mutex.py +++ b/aggregatetunnelmetrics/lockedfile/mutex.py @@ -1,4 +1,8 @@ -"""Mutex class for mutual exclusion using flock""" +""" +Internal implementation of mutex class for mutual exclusion using flock. + +You should typically import `lockefile` directly instead. +""" # SPDX-License-Identifier: GPL-3.0-or-later @@ -36,6 +40,8 @@ class Mutex: if self.filep: try: fcntl.flock(self.filep.fileno(), fcntl.LOCK_UN) + except OSError: + pass # Suppress OSError during unlock finally: self.filep.close() self.filep = None diff --git a/tests/aggregatetunnelmetrics/lockedfile/fileio.py b/tests/aggregatetunnelmetrics/lockedfile/fileio.py deleted file mode 100644 index 7231cb5..0000000 --- a/tests/aggregatetunnelmetrics/lockedfile/fileio.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Tests for the lockedfile/fileio.py functionality.""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Set - -import multiprocessing as mp -import os -import tempfile -import time -import unittest - -from aggregatetunnelmetrics.lockedfile import common, fileio - - -class TestFileIO(unittest.TestCase): - """Tests for the lockedfile/fileio.py functionality.""" - - def setUp(self): - self.temp_fd, self.temp_path = tempfile.mkstemp() - - def tearDown(self): - os.close(self.temp_fd) - os.unlink(self.temp_path) - - def test_concurrent_access(self): - """ - The purpose of this test is to spawn many readers and - writers using background processes, to ensure that, - regardless of concurrent access attempts, we are always - able to write to the file, and read from the file, - consistent strings. With enough repetitions, if there - are actual race conditions, the file content will - possibly eventually be corrupted and we will notice. - """ - - # TODO(bassosimone): ensure we have some cases of contention - # by collecting statistics about retries. - - valid_contents = ["foo", "bar", "baz", "qux"] - results_queue = mp.Queue() - should_stop = mp.Event() - config = fileio.FileIOConfig(num_retries=30, sleep_interval=0.1) - - def writer_process(): - while not should_stop.is_set(): - for content in valid_contents: - fileio.write(self.temp_path, content, config=config) - - writers = [mp.Process(target=writer_process) for _ in range(4)] - - def reader_process(): - while not should_stop.is_set(): - content = fileio.read(self.temp_path, config=config) - if content: - results_queue.put(content) - - readers = [mp.Process(target=reader_process) for _ in range(8)] - - # Start all processes - for p in readers + writers: - p.start() - - # Allow the processes to run for a while - time.sleep(1) - - # Interrupt the processes - should_stop.set() - - # Wait for processes to terminate - for p in readers + writers: - p.join() - - # Collect the results from the queue - observed_contents = set() - while not results_queue.empty(): - value = results_queue.get() - observed_contents.add(value) - - # Ensure we never read garbled data - self.assertTrue( - len(observed_contents) > 0, - "No data was read", - ) - self.assertTrue( - observed_contents.issubset(valid_contents), - f"Observed contents: {observed_contents}", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py new file mode 100644 index 0000000..3974bff --- /dev/null +++ b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py @@ -0,0 +1,164 @@ +"""Tests for the lockedfile/fileio.py functionality.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Set +from unittest.mock import patch, mock_open + +import fcntl +import multiprocessing as mp +import os +import tempfile +import time +import unittest + +from aggregatetunnelmetrics.lockedfile import common, fileio + + +class TestFileIOUnit(unittest.TestCase): + """Unit tests for the lockedfile/fileio.py functionality.""" + + def test_read_success(self): + """Test successful file read with lock acquisition.""" + mock_content = "test content" + m = mock_open(read_data=mock_content) + + with patch("builtins.open", m), patch("fcntl.flock") as mock_flock: + result = fileio.read("test.txt") + self.assertEqual(result, mock_content) + m().read.assert_called_once() + self.assertEqual(mock_flock.call_count, 2) # acquire and release + + def test_read_lock_failure(self): + """Test read fails when unable to acquire lock.""" + m = mock_open() + + with patch("builtins.open", m), patch( + "fcntl.flock", side_effect=BlockingIOError() + ), patch("time.sleep"): + + with self.assertRaises(common.FileLockError): + fileio.read("test.txt") + + def test_write_success(self): + """Test successful file write with lock acquisition.""" + test_content = "test content" + m = mock_open() + + with patch("builtins.open", m), patch("fcntl.flock") as mock_flock, patch( + "os.fsync" + ) as mock_fsync: + + fileio.write("test.txt", test_content) + + m().write.assert_called_once_with(test_content) + m().flush.assert_called_once() + mock_fsync.assert_called_once() + self.assertEqual(mock_flock.call_count, 2) # acquire and release + + def test_write_lock_failure(self): + """Test write fails when unable to acquire lock.""" + m = mock_open() + + with patch("builtins.open", m), patch( + "fcntl.flock", side_effect=BlockingIOError() + ), patch("time.sleep"): + + with self.assertRaises(common.FileLockError): + fileio.write("test.txt", "content") + + def test_release_error_is_suppressed(self): + """Test that errors during lock release are suppressed.""" + mock_content = "test content" + m = mock_open(read_data=mock_content) + + def flock_side_effect(fd, operation): + if operation == fcntl.LOCK_UN: # When trying to unlock + raise OSError("mock release error") + return None # Success for lock acquisition + + with patch("builtins.open", m), patch( + "fcntl.flock", side_effect=flock_side_effect + ): + # Should complete without raising exception + result = fileio.read("test.txt") + self.assertEqual(result, mock_content) + + +class TestFileIOIntegration(unittest.TestCase): + """Integration tests for the lockedfile/fileio.py functionality.""" + + def setUp(self): + self.temp_fd, self.temp_path = tempfile.mkstemp() + + def tearDown(self): + os.close(self.temp_fd) + os.unlink(self.temp_path) + + def test_concurrent_access(self): + """ + The purpose of this test is to spawn many readers and + writers using background processes, to ensure that, + regardless of concurrent access attempts, we are always + able to write to the file, and read from the file, + consistent strings. With enough repetitions, if there + are actual race conditions, the file content will + possibly eventually be corrupted and we will notice. + """ + + # TODO(bassosimone): ensure we have some cases of contention + # by collecting statistics about retries. + + valid_contents = ["foo", "bar", "baz", "qux"] + results_queue = mp.Queue() + should_stop = mp.Event() + config = fileio.FileIOConfig(num_retries=30, sleep_interval=0.1) + + def writer_process(): + while not should_stop.is_set(): + for content in valid_contents: + fileio.write(self.temp_path, content, config=config) + + writers = [mp.Process(target=writer_process) for _ in range(4)] + + def reader_process(): + while not should_stop.is_set(): + content = fileio.read(self.temp_path, config=config) + if content: + results_queue.put(content) + + readers = [mp.Process(target=reader_process) for _ in range(8)] + + # Start all processes + for p in readers + writers: + p.start() + + # Allow the processes to run for a while + time.sleep(1) + + # Interrupt the processes + should_stop.set() + + # Wait for processes to terminate + for p in readers + writers: + p.join() + + # Collect the results from the queue + observed_contents = set() + while not results_queue.empty(): + value = results_queue.get() + observed_contents.add(value) + + # Ensure we never read garbled data + self.assertTrue( + len(observed_contents) > 0, + "No data was read", + ) + self.assertTrue( + observed_contents.issubset(valid_contents), + f"Observed contents: {observed_contents}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py new file mode 100644 index 0000000..d39c362 --- /dev/null +++ b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py @@ -0,0 +1,187 @@ +"""Tests for the lockedfile/mutex.py functionality.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from unittest.mock import patch, mock_open +import fcntl +import multiprocessing as mp +import os +import tempfile +import time +import unittest + +from aggregatetunnelmetrics.lockedfile import Mutex, FileLockError + + +class TestMutexUnit(unittest.TestCase): + """Unit tests for the lockedfile/mutex.py functionality.""" + + def test_mutex_acquire_success(self): + """Test successful mutex lock acquisition.""" + m = mock_open() + + with patch("builtins.open", m), patch("fcntl.flock") as mock_flock: + with Mutex("test.lock"): + mock_flock.assert_called_once_with( + m().fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB + ) + + def test_mutex_release(self): + """Test mutex properly releases lock.""" + m = mock_open() + + with patch("builtins.open", m), patch("fcntl.flock") as mock_flock: + with Mutex("test.lock"): + pass + + # Second call should be unlock + mock_flock.assert_called_with(m().fileno(), fcntl.LOCK_UN) + self.assertEqual(mock_flock.call_count, 2) + m().close.assert_called_once() + + def test_mutex_acquire_failure_io_error(self): + """Test mutex fails to acquire with IOError.""" + m = mock_open() + + with patch("builtins.open", m), patch( + "fcntl.flock", side_effect=IOError("mock error") + ): + + with self.assertRaises(FileLockError) as cm: + with Mutex("test.lock"): + pass + self.assertIn("mock error", str(cm.exception)) + m().close.assert_called_once() + + def test_mutex_acquire_failure_blocking(self): + """Test mutex fails to acquire when already locked.""" + m = mock_open() + + with patch("builtins.open", m), patch( + "fcntl.flock", side_effect=BlockingIOError() + ): + + with self.assertRaises(FileLockError) as cm: + with Mutex("test.lock"): + pass + self.assertIn("cannot acquire lock", str(cm.exception)) + m().close.assert_called_once() + + def test_mutex_release_error_suppressed(self): + """Test that errors during mutex release are suppressed.""" + m = mock_open() + + def flock_side_effect(fd, operation): + if operation == fcntl.LOCK_UN: + raise OSError("mock release error") + return None + + with patch("builtins.open", m), patch( + "fcntl.flock", side_effect=flock_side_effect + ): + # Should complete without raising exception + with Mutex("test.lock"): + pass + m().close.assert_called_once() + + +class TestMutexIntegration(unittest.TestCase): + """Integration tests for the lockedfile/mutex.py functionality.""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + self.lock_path = os.path.join(self.temp_dir, "test.lock") + self.temp_path = os.path.join(self.temp_dir, "test.txt") + + def tearDown(self): + try: + os.unlink(self.lock_path) + os.unlink(self.temp_path) + os.rmdir(self.temp_dir) + except FileNotFoundError: + pass + + def test_concurrent_access(self): + """ + Test that mutex provides mutual exclusion with concurrent access. + By writing only from a fixed set of valid strings, any corruption + would result in invalid content that wouldn't match our set. + """ + valid_contents = ["alpha", "beta", "gamma", "delta"] + results_queue = mp.Queue() + should_stop = mp.Event() + + def writer_process(): + try: + while not should_stop.is_set(): + for content in valid_contents: + while not should_stop.is_set(): + try: + with Mutex(self.lock_path): + with open(self.temp_path, "w") as f: + f.write(content) + f.flush() + os.fsync(f.fileno()) + break # Success! Move to next content + except FileLockError: + # Lock busy, try again + time.sleep(0.1) + except Exception as exc: + results_queue.put(exc) + + def reader_process(): + try: + while not should_stop.is_set(): + try: + with Mutex(self.lock_path): + with open(self.temp_path) as f: + content = f.read() + if content: + results_queue.put(content) + break # Success! Continue outer loop + except FileLockError: + # Lock busy, try again + time.sleep(0.1) + except Exception as exc: + results_queue.put(exc) + + # Create and initialize the file + with open(self.temp_path, "w") as f: + f.write(valid_contents[0]) + + # Start processes + writers = [mp.Process(target=writer_process) for _ in range(4)] + readers = [mp.Process(target=reader_process) for _ in range(8)] + + for p in readers + writers: + p.start() + + # Let them run for a bit + time.sleep(1) + should_stop.set() + + # Wait for all processes + for p in readers + writers: + p.join() + + # Check for any errors + observed_contents = set() + while not results_queue.empty(): + item = results_queue.get() + if isinstance(item, Exception): + raise item + observed_contents.add(item) + + # Verify we only saw valid contents + self.assertTrue( + len(observed_contents) > 0, + "No data was read", + ) + self.assertTrue( + observed_contents.issubset(valid_contents), + f"Observed corrupted contents: {observed_contents}", + ) + + +if __name__ == "__main__": + unittest.main() -- GitLab From 04b4392f2acbe88399aba8e91beddc526bc4210d Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 18:22:29 +0100 Subject: [PATCH 44/75] feat(oonireport): improve tests --- .../oonireport/collector.py | 8 +- aggregatetunnelmetrics/oonireport/load.py | 44 ++--- aggregatetunnelmetrics/oonireport/model.py | 23 ++- .../oonireport/__init__.py | 2 +- .../oonireport/test_collector.py | 120 ++++++++++++ .../oonireport/test_load.py | 183 +++++++++++++----- .../oonireport/test_model.py | 80 +++++++- 7 files changed, 370 insertions(+), 90 deletions(-) diff --git a/aggregatetunnelmetrics/oonireport/collector.py b/aggregatetunnelmetrics/oonireport/collector.py index e035925..db519c6 100644 --- a/aggregatetunnelmetrics/oonireport/collector.py +++ b/aggregatetunnelmetrics/oonireport/collector.py @@ -1,7 +1,9 @@ """ -Implements the OONI collector protocol. +Internal implementation the OONI collector protocol. See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. + +Please, prefer importing the `oonireport` package directly. """ # SPDX-License-Identifier: GPL-3.0-or-later @@ -127,7 +129,9 @@ class CollectorClient: Raises: model.APIError: in case of failure. """ - measurement.report_id = report_id # Required for Explorer visualization + measurement = measurement.with_report_id( + report_id + ) # Required for Explorer visualization data = json.dumps( { "format": "json", diff --git a/aggregatetunnelmetrics/oonireport/load.py b/aggregatetunnelmetrics/oonireport/load.py index 8ae9eab..28ac8e3 100644 --- a/aggregatetunnelmetrics/oonireport/load.py +++ b/aggregatetunnelmetrics/oonireport/load.py @@ -1,4 +1,8 @@ -"""Functions for loading OONI measurements from a given file.""" +""" +Internal functions for loading OONI measurements from a given file. + +Please, prefer importing the `oonireport` package directly. +""" # SPDX-License-Identifier: GPL-3.0-or-later @@ -20,7 +24,7 @@ class _DictTestKeys: return self._keys -def _load_measurement(raw_data: str) -> Measurement: +def load_single_measurement(raw_data: str) -> Measurement: """ Creates a Measurement from JSON data. @@ -96,54 +100,42 @@ def _load_measurement(raw_data: str) -> Measurement: ) -def load_measurements(path: str) -> List[Measurement]: +def stream_measurements(path: str) -> Iterator[Measurement]: """ - Loads measurements from a JSON file containing one measurement per line. + Streams measurements from a JSON file containing one measurement per line. + + Is more efficient than load_measurements() for large files since it reads + line by line instead of loading the entire file into memory. Args: path: Path to the JSON file Returns: - List of Measurement instances. + Iterator yielding Measurement instances. Raises: ValueError: if the file format is invalid. OSError: if file operations fail. """ with open(path) as filep: - content = filep.read().strip() - if not content: - return [] - - # Try parsing as newline-delimited JSON - measurements = [] - for line in content.split("\n"): + for line in filep: line = line.strip() if line: # Skip empty lines - measurements.append(_load_measurement(line)) + yield load_single_measurement(line) - return measurements - -def stream_measurements(path: str) -> Iterator[Measurement]: +def load_measurements(path: str) -> List[Measurement]: """ - Streams measurements from a JSON file containing one measurement per line. - - Is more efficient than load_measurements() for large files. + Loads measurements from a JSON file containing one measurement per line. Args: path: Path to the JSON file Returns: - Iterator yielding Measurement instances. + List of Measurement instances. Raises: ValueError: if the file format is invalid. OSError: if file operations fail. - - Note: - Not implemented yet. """ - # TODO(bassosimone): implement this function and reimplement - # load_measurements() in terms of it. - raise NotImplementedError() + return list(stream_measurements(path)) diff --git a/aggregatetunnelmetrics/oonireport/model.py b/aggregatetunnelmetrics/oonireport/model.py index 99faec9..3fae6d2 100644 --- a/aggregatetunnelmetrics/oonireport/model.py +++ b/aggregatetunnelmetrics/oonireport/model.py @@ -1,12 +1,14 @@ """ -OONI measurement model. +Internal definition of the OONI measurement model. See https://github.com/ooni/spec/blob/master/data-formats/df-000-base.md. + +Please, prefer importing the `oonireport` package directly. """ # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import datetime, timezone from typing import Any, Dict, Protocol @@ -16,14 +18,17 @@ class APIError(Exception): class TestKeys(Protocol): - """Models the OONI measurement test keys.""" + """ + Models the OONI measurement test keys. + + Methods: + as_dict: Converts the test keys to a JSON-serializable dict. + """ - def as_dict(self) -> Dict[str, Any]: - """Converts the test keys to a JSON-serializable dict.""" - ... + def as_dict(self) -> Dict[str, Any]: ... -@dataclass +@dataclass(frozen=True) class Measurement: """Models the OONI measurement envelope.""" @@ -97,6 +102,10 @@ class Measurement: return dct + def with_report_id(self, report_id: str) -> "Measurement": + """Creates a new Measurement instance with the given report_id.""" + return replace(self, report_id=report_id) + def datetime_to_ooni_format(dt: datetime) -> str: """Converts a datetime to OONI's datetime format (YYYY-mm-dd HH:MM:SS).""" diff --git a/tests/aggregatetunnelmetrics/oonireport/__init__.py b/tests/aggregatetunnelmetrics/oonireport/__init__.py index cbbb2ce..8b13789 100644 --- a/tests/aggregatetunnelmetrics/oonireport/__init__.py +++ b/tests/aggregatetunnelmetrics/oonireport/__init__.py @@ -1 +1 @@ -"""Tests for the oonireport module.""" + diff --git a/tests/aggregatetunnelmetrics/oonireport/test_collector.py b/tests/aggregatetunnelmetrics/oonireport/test_collector.py index aa48abd..df2ec90 100644 --- a/tests/aggregatetunnelmetrics/oonireport/test_collector.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_collector.py @@ -1,5 +1,7 @@ """Tests for the collector module.""" +# SPDX-License-Identifier: GPL-3.0-or-later + from unittest.mock import patch from datetime import datetime, timezone @@ -148,3 +150,121 @@ class TestCollectorClient(unittest.TestCase): self.client.update_report("test_report_id", measurement) self.assertIn("HTTP error", str(cm.exception)) + + @patch("urllib.request.urlopen") + def test_create_report_from_measurement(self, mock_urlopen): + """Test create_report_from_measurement properly calls create_report.""" + mock_urlopen.return_value = MockResponse(200, {"report_id": "test_report_id"}) + + # Create a measurement instance + measurement = Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=self.dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=self.dt, + test_version="0.0.1", + ) + + report_id = self.client.create_report_from_measurement(measurement) + + self.assertEqual(report_id, "test_report_id") + # Verify the correct URL was called + self.assertEqual(mock_urlopen.call_count, 1) + request = mock_urlopen.call_args[0][0] + self.assertEqual(request.full_url, "https://example.org/report") + + @patch("urllib.request.urlopen") + def test_create_report_api_error(self, mock_urlopen): + """Test create_report handles non-200 status code.""" + mock_urlopen.return_value = MockResponse( + 500, {"error": "Internal Server Error"} + ) + + with self.assertRaises(APIError) as cm: + self.client.create_report( + test_name="web_connectivity", + test_version="0.0.1", + software_name="ooniprobe", + software_version="3.0.0", + probe_asn="AS12345", + probe_cc="IT", + test_start_time=self.dt, + ) + + self.assertIn("unexpected status: 500", str(cm.exception)) + + @patch("urllib.request.urlopen") + def test_update_report_api_error(self, mock_urlopen): + """Test update_report handles non-200 status code.""" + mock_urlopen.return_value = MockResponse(403, {"error": "Forbidden"}) + + measurement = Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=self.dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=self.dt, + test_version="0.0.1", + ) + + with self.assertRaises(APIError) as cm: + self.client.update_report("test_report_id", measurement) + + self.assertIn("unexpected status: 403", str(cm.exception)) + + @patch("urllib.request.urlopen") + def test_update_report_invalid_json(self, mock_urlopen): + """Test update_report handles invalid JSON response.""" + + class InvalidJSONResponse: + status = 200 + + def read(self): + return b"invalid json" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + mock_urlopen.return_value = InvalidJSONResponse() + + measurement = Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=self.dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=self.dt, + test_version="0.0.1", + ) + + # Should return None when JSON decoding fails but status is 200 + result = self.client.update_report("test_report_id", measurement) + self.assertIsNone(result) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/aggregatetunnelmetrics/oonireport/test_load.py b/tests/aggregatetunnelmetrics/oonireport/test_load.py index 10f795b..e4e2211 100644 --- a/tests/aggregatetunnelmetrics/oonireport/test_load.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_load.py @@ -1,14 +1,19 @@ """Tests for measurement loading functionality.""" +# SPDX-License-Identifier: GPL-3.0-or-later + import json import os import tempfile import unittest -from aggregatetunnelmetrics.oonireport import load_measurements - +from aggregatetunnelmetrics.oonireport.load import ( + load_measurements, + load_single_measurement, +) -SAMPLE_MEASUREMENT = { +# Valid measurement for testing +VALID_MEASUREMENT = { "annotations": {}, "data_format_version": "0.2.0", "input": "https://example.com/", @@ -25,53 +30,125 @@ SAMPLE_MEASUREMENT = { } -class TestMeasurementLoading(unittest.TestCase): - - def test_load_from_file(self): - with tempfile.NamedTemporaryFile(mode="w", delete=False) as filep: - filep.write(json.dumps(SAMPLE_MEASUREMENT) + "\n") - filep.write(json.dumps(SAMPLE_MEASUREMENT) + "\n") - - try: - measurements = load_measurements(filep.name) - self.assertEqual(len(measurements), 2) - for m in measurements: - dm = m.as_dict() - self.assertEqual(dm["data_format_version"], "0.2.0") - self.assertEqual(dm["input"], "https://example.com/") - self.assertEqual(dm["measurement_start_time"], "2023-01-01 12:00:00") - self.assertEqual(dm["probe_asn"], "AS12345") - self.assertEqual(dm["probe_cc"], "IT") - self.assertEqual(dm["probe_ip"], "127.0.0.1") - self.assertEqual(dm["report_id"], "") - self.assertEqual(dm["software_name"], "ooniprobe") - self.assertEqual(dm["software_version"], "3.0.0") - self.assertEqual(dm["test_keys"], {"simple": "test"}) - self.assertEqual(dm["test_name"], "web_connectivity") - self.assertEqual(dm["test_runtime"], 1.0) - self.assertEqual(dm["test_start_time"], "2023-01-01 12:00:00") - self.assertEqual(dm["test_version"], "0.0.1") - self.assertEqual( - set(dm.keys()), - set( - [ - "annotations", - "data_format_version", - "input", - "measurement_start_time", - "probe_asn", - "probe_cc", - "probe_ip", - "report_id", - "software_name", - "software_version", - "test_keys", - "test_name", - "test_runtime", - "test_start_time", - "test_version", - ] - ), - ) - finally: - os.unlink(filep.name) +class TestMeasurementValidation(unittest.TestCase): + """Tests for validating individual measurements.""" + + def test_missing_required_fields(self): + """Test that missing required fields raise ValueError.""" + # Remove a required field + invalid_measurement = VALID_MEASUREMENT.copy() + del invalid_measurement["probe_asn"] + + with self.assertRaises(ValueError) as cm: + load_single_measurement(json.dumps(invalid_measurement)) + + self.assertIn("Missing required fields", str(cm.exception)) + self.assertIn("probe_asn", str(cm.exception)) + + def test_invalid_date_format_measurement_start_time(self): + """Test that invalid measurement_start_time format raises ValueError.""" + invalid_measurement = VALID_MEASUREMENT.copy() + invalid_measurement["measurement_start_time"] = ( + "2023-13-32 25:61:61" # Invalid date + ) + + with self.assertRaises(ValueError) as cm: + load_single_measurement(json.dumps(invalid_measurement)) + + self.assertIn("Invalid datetime format", str(cm.exception)) + + def test_invalid_date_format_test_start_time(self): + """Test that invalid test_start_time format raises ValueError.""" + invalid_measurement = VALID_MEASUREMENT.copy() + invalid_measurement["test_start_time"] = "not-a-date" # Invalid date + + with self.assertRaises(ValueError) as cm: + load_single_measurement(json.dumps(invalid_measurement)) + + self.assertIn("Invalid datetime format", str(cm.exception)) + + +class TestFileLoading(unittest.TestCase): + """Tests for loading measurements from files.""" + + def setUp(self): + """Create a temporary file for testing.""" + self.temp_file = tempfile.NamedTemporaryFile(mode="w", delete=False) + + def tearDown(self): + """Clean up temporary file.""" + os.unlink(self.temp_file.name) + + def test_load_valid_measurements(self): + """Test loading multiple valid measurements.""" + # Write two valid measurements to file + with open(self.temp_file.name, "w") as f: + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + + measurements = load_measurements(self.temp_file.name) + + self.assertEqual(len(measurements), 2) + for m in measurements: + # Test all fields + self.assertEqual(m.annotations, {}) + self.assertEqual(m.data_format_version, "0.2.0") + self.assertEqual(m.input, "https://example.com/") + self.assertEqual(m.probe_asn, "AS12345") + self.assertEqual(m.probe_cc, "IT") + self.assertEqual(m.probe_ip, "127.0.0.1") + self.assertEqual(m.report_id, "") + self.assertEqual(m.software_name, "ooniprobe") + self.assertEqual(m.software_version, "3.0.0") + self.assertEqual(m.test_keys.as_dict(), {"simple": "test"}) + self.assertEqual(m.test_name, "web_connectivity") + self.assertEqual(m.test_runtime, 1.0) + self.assertEqual(m.test_version, "0.0.1") + + # Verify we're only loading expected fields + expected_fields = { + "annotations", + "data_format_version", + "input", + "measurement_start_time", + "probe_asn", + "probe_cc", + "probe_ip", + "report_id", + "software_name", + "software_version", + "test_keys", + "test_name", + "test_runtime", + "test_start_time", + "test_version", + } + self.assertEqual(set(m.as_dict().keys()), expected_fields) + + def test_load_empty_file(self): + """Test loading from an empty file.""" + with open(self.temp_file.name, "w") as f: + f.write("") + + measurements = load_measurements(self.temp_file.name) + self.assertEqual(len(measurements), 0) + + def test_load_file_with_blank_lines(self): + """Test loading file with blank lines between measurements.""" + with open(self.temp_file.name, "w") as f: + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + f.write("\n") # blank line + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + f.write("\n") # blank line + + measurements = load_measurements(self.temp_file.name) + self.assertEqual(len(measurements), 2) + + def test_file_not_found(self): + """Test attempting to load from non-existent file.""" + with self.assertRaises(OSError): + load_measurements("nonexistent_file.json") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/aggregatetunnelmetrics/oonireport/test_model.py b/tests/aggregatetunnelmetrics/oonireport/test_model.py index 258f7a0..40a8ea4 100644 --- a/tests/aggregatetunnelmetrics/oonireport/test_model.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_model.py @@ -1,5 +1,7 @@ """Tests for the model module.""" +# SPDX-License-Identifier: GPL-3.0-or-later + from datetime import datetime, timedelta, timezone import unittest @@ -16,7 +18,8 @@ class SimpleTestKeys: class TestModel(unittest.TestCase): - def test_measurement_as_dict(self): + def test_measurement_as_dict_minimal(self): + """Test with just the mandatory fields.""" dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) measurement = Measurement( annotations={"annotation_key": "value"}, @@ -36,6 +39,7 @@ class TestModel(unittest.TestCase): data = measurement.as_dict() + # Check mandatory fields self.assertEqual(data["annotations"], {"annotation_key": "value"}) self.assertEqual(data["data_format_version"], "0.2.0") self.assertEqual(data["input"], "https://example.com") @@ -50,6 +54,80 @@ class TestModel(unittest.TestCase): self.assertEqual(data["test_start_time"], "2023-01-01 12:00:00") self.assertEqual(data["test_version"], "0.0.1") + # Check default values + self.assertEqual(data["probe_ip"], "127.0.0.1") + self.assertEqual(data["report_id"], "") + + # Check that optional fields are not present + self.assertNotIn("options", data) + self.assertNotIn("probe_network_name", data) + self.assertNotIn("resolver_asn", data) + self.assertNotIn("resolver_ip", data) + self.assertNotIn("resolver_network_name", data) + self.assertNotIn("test_helpers", data) + + def test_measurement_as_dict_all_fields(self): + """Test with all fields, including optional ones.""" + dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) + measurement = Measurement( + # Mandatory fields + annotations={"annotation_key": "value"}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=dt, + test_version="0.0.1", + # Fields with default values + probe_ip="93.184.216.34", + report_id="20230101_IT_test", + # Optional fields + options=["option1", "option2"], + probe_network_name="Example ISP", + resolver_asn="AS12346", + resolver_ip="8.8.8.8", + resolver_network_name="Example DNS", + test_helpers={"dns": "8.8.8.8", "web": "web-connectivity.example.org"}, + ) + + data = measurement.as_dict() + + # Check mandatory fields + self.assertEqual(data["annotations"], {"annotation_key": "value"}) + self.assertEqual(data["data_format_version"], "0.2.0") + self.assertEqual(data["input"], "https://example.com") + self.assertEqual(data["measurement_start_time"], "2023-01-01 12:00:00") + self.assertEqual(data["probe_asn"], "AS12345") + self.assertEqual(data["probe_cc"], "IT") + self.assertEqual(data["software_name"], "ooniprobe") + self.assertEqual(data["software_version"], "3.0.0") + self.assertEqual(data["test_keys"], {"simple": "test"}) + self.assertEqual(data["test_name"], "web_connectivity") + self.assertEqual(data["test_runtime"], 1.0) + self.assertEqual(data["test_start_time"], "2023-01-01 12:00:00") + self.assertEqual(data["test_version"], "0.0.1") + + # Check fields with default values + self.assertEqual(data["probe_ip"], "93.184.216.34") + self.assertEqual(data["report_id"], "20230101_IT_test") + + # Check optional fields + self.assertEqual(data["options"], ["option1", "option2"]) + self.assertEqual(data["probe_network_name"], "Example ISP") + self.assertEqual(data["resolver_asn"], "AS12346") + self.assertEqual(data["resolver_ip"], "8.8.8.8") + self.assertEqual(data["resolver_network_name"], "Example DNS") + self.assertEqual( + data["test_helpers"], + {"dns": "8.8.8.8", "web": "web-connectivity.example.org"}, + ) + def test_datetime_to_ooni_format_utc(self): dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) formatted = datetime_to_ooni_format(dt) -- GitLab From 7defb0941ac699e04cf3cd3d50055cd16c6e0920 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 19:02:31 +0100 Subject: [PATCH 45/75] feat(oonireport): add coverage for upload scripts --- .coveragerc | 3 + aggregatetunnelmetrics/oonireport/__main__.py | 7 +- .../fieldtestingcsv/test_parser.py | 2 +- .../oonireport/test_main.py | 269 ++++++++++++++++++ 4 files changed, 276 insertions(+), 5 deletions(-) create mode 100644 .coveragerc create mode 100644 tests/aggregatetunnelmetrics/oonireport/test_main.py diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..cc1e46a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +source = aggregatetunnelmetrics +omit = tests/* diff --git a/aggregatetunnelmetrics/oonireport/__main__.py b/aggregatetunnelmetrics/oonireport/__main__.py index 8738509..87e8a6f 100644 --- a/aggregatetunnelmetrics/oonireport/__main__.py +++ b/aggregatetunnelmetrics/oonireport/__main__.py @@ -58,9 +58,7 @@ def main(args: Optional[List[str]] = None) -> int: ) opts = parser.parse_args(args) - if opts.command != "upload": - print("oonireport: unknown command", file=sys.stderr) - return 1 + assert opts.command == "upload" # parse_args throws on invalid usage try: measurements = load_measurements(opts.file) @@ -112,9 +110,10 @@ def main(args: Optional[List[str]] = None) -> int: os.unlink(opts.file) except Exception as exc: print(f"oonireport: failed to delete input file: {exc}", file=sys.stderr) + return 1 return 0 if numfailures == 0 else 1 -if __name__ == "__main__": +if __name__ == "__main__": # pragma: no cover sys.exit(main()) diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index 16574ce..bd002c4 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -8,7 +8,7 @@ import tempfile import os import logging -logging.basicConfig(level=logging.ERROR) # do not spew when running tests +logging.basicConfig(level=logging.ERROR) # do not log when running tests from aggregatetunnelmetrics.fieldtestingcsv.model import Entry from aggregatetunnelmetrics.fieldtestingcsv.parser import ( diff --git a/tests/aggregatetunnelmetrics/oonireport/test_main.py b/tests/aggregatetunnelmetrics/oonireport/test_main.py new file mode 100644 index 0000000..f989385 --- /dev/null +++ b/tests/aggregatetunnelmetrics/oonireport/test_main.py @@ -0,0 +1,269 @@ +"""Tests for the command line interface.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone +from unittest.mock import ANY, patch + +import json +import tempfile +import os +import unittest + +from aggregatetunnelmetrics.oonireport.__main__ import main +from aggregatetunnelmetrics.oonireport.model import Measurement + + +class SimpleTestKeys: + """Simple TestKeys implementation for testing.""" + + def as_dict(self): + return {"simple": "test"} + + +class TestMain(unittest.TestCase): + def setUp(self): + self.dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) + self.valid_measurement = Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=self.dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=self.dt, + test_version="0.0.1", + ) + + @patch("sys.stderr") + def test_main_no_args(self, mock_stderr): + """Test main with no arguments.""" + with patch("sys.argv", ["oonireport"]): + with self.assertRaises(SystemExit): + main() + + @patch("sys.stderr") + @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") + def test_main_successful_upload(self, mock_client_class, mock_stderr): + """Test successful measurement upload.""" + # Create temporary file with measurement + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + json.dump(self.valid_measurement.as_dict(), tf) + tf.write("\n") + tf.flush() + temp_path = tf.name + + try: + # Create mock client instance + mock_client = mock_client_class.return_value + + # Mock successful API responses + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.update_report.return_value = "test_measurement_uid" + + # Run main with arguments + args = ["upload", "-f", temp_path] + exit_code = main(args) + + # Verify success + self.assertEqual(exit_code, 0) + + # Verify API calls + mock_client.create_report_from_measurement.assert_called_once_with( + ANY, + ) + mock_client.update_report.assert_called_once_with( + "test_report_id", + ANY, + ) + + finally: + os.unlink(temp_path) + + @patch("sys.stderr") + @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") + def test_main_failed_upload(self, mock_client_class, mock_stderr): + """Test failed measurement upload.""" + # Create temporary file with measurement + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + json.dump(self.valid_measurement.as_dict(), tf) + tf.write("\n") + tf.flush() + temp_path = tf.name + + try: + # Create mock client instance + mock_client = mock_client_class.return_value + + # Mock API failure + mock_client.create_report_from_measurement.side_effect = Exception( + "API Error" + ) + + # Run main with arguments + args = ["upload", "-f", temp_path] + exit_code = main(args) + + # Verify failure and that the method was actually called + self.assertEqual(exit_code, 1) + mock_client.create_report_from_measurement.assert_called_once_with( + ANY, + ) + + finally: + os.unlink(temp_path) + + @patch("sys.stderr") + @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") + def test_main_dump_failed(self, mock_client_class, mock_stderr): + """Test dumping failed measurements.""" + # Create temporary file with measurement + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + json.dump(self.valid_measurement.as_dict(), tf) + tf.write("\n") + tf.flush() + temp_path = tf.name + + try: + # Create mock client instance + mock_client = mock_client_class.return_value + + # Mock API failure + mock_client.create_report_from_measurement.side_effect = Exception( + "API Error" + ) + + # Run main with arguments and capture stdout + with patch("sys.stdout") as mock_stdout: + args = ["upload", "-f", temp_path, "--dump-failed"] + exit_code = main(args) + + # Verify failure + self.assertEqual(exit_code, 1) + mock_client.create_report_from_measurement.assert_called_once_with( + ANY, + ) + + # Verify failed measurement was dumped + mock_stdout.write.assert_called() + + finally: + os.unlink(temp_path) + + @patch("sys.stderr") + def test_main_invalid_file(self, mock_stderr): + """Test handling of invalid measurement file.""" + args = ["upload", "-f", "nonexistent_file.json"] + exit_code = main(args) + self.assertEqual(exit_code, 1) + + @patch("sys.stderr") + @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") + @patch("os.unlink") + def test_main_delete_input(self, mock_unlink, mock_client_class, mock_stderr): + """Test deleting input file after successful upload.""" + # Create temporary file with measurement + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + json.dump(self.valid_measurement.as_dict(), tf) + tf.write("\n") + tf.flush() + temp_path = tf.name + + try: + # Create mock client instance + mock_client = mock_client_class.return_value + + # Mock successful API responses + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.update_report.return_value = "test_measurement_uid" + + # Run main with delete flag + args = ["upload", "-f", temp_path, "--delete-input-file"] + exit_code = main(args) + + # Verify success and file deletion + self.assertEqual(exit_code, 0) + mock_unlink.assert_called_once_with(temp_path) + + finally: + if os.path.exists(temp_path): + os.unlink(temp_path) + + @patch("sys.stderr") + def test_main_unknown_command(self, mock_stderr): + """Test main with unknown command.""" + args = ["unknown-command"] + with self.assertRaises(SystemExit) as cm: + main(args) + self.assertEqual(cm.exception.code, 2) # argparse uses exit code 2 + + @patch("sys.stderr") + def test_main_no_measurements(self, mock_stderr): + """Test main with empty measurement file.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + temp_path = tf.name + + try: + args = ["upload", "-f", temp_path] + exit_code = main(args) + self.assertEqual(exit_code, 1) + finally: + os.unlink(temp_path) + + @patch("sys.stderr") + @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") + def test_main_no_measurement_uid(self, mock_client_class, mock_stderr): + """Test handling of missing measurement UID.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + json.dump(self.valid_measurement.as_dict(), tf) + tf.write("\n") + tf.flush() + temp_path = tf.name + + try: + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.update_report.return_value = None # No measurement UID + + args = ["upload", "-f", temp_path] + exit_code = main(args) + + self.assertEqual(exit_code, 0) + mock_client.update_report.assert_called_once_with("test_report_id", ANY) + finally: + os.unlink(temp_path) + + @patch("sys.stderr") + @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") + def test_main_unlink_failure(self, mock_client_class, mock_stderr): + """Test handling of file deletion failure.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: + json.dump(self.valid_measurement.as_dict(), tf) + tf.write("\n") + tf.flush() + temp_path = tf.name + + try: + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.update_report.return_value = "test_measurement_uid" + + with patch("os.unlink") as mock_unlink: + mock_unlink.side_effect = OSError("Permission denied") + + args = ["upload", "-f", temp_path, "--delete-input-file"] + exit_code = main(args) + + self.assertEqual(exit_code, 1) # Should fail if can't delete + mock_unlink.assert_called_once_with(temp_path) + finally: + os.unlink(temp_path) + + +if __name__ == "__main__": + unittest.main() -- GitLab From 1476cbd889329115090b75f973946985fbeb179b Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 19:30:51 +0100 Subject: [PATCH 46/75] feat(globalscope): write unit tests --- .../globalscope/__init__.py | 144 +-------- .../globalscope/aggregate.py | 135 ++++++++ .../globalscope/__init__.py | 0 .../globalscope/test_aggregate.py | 296 ++++++++++++++++++ 4 files changed, 437 insertions(+), 138 deletions(-) create mode 100644 aggregatetunnelmetrics/globalscope/aggregate.py create mode 100644 tests/aggregatetunnelmetrics/globalscope/__init__.py create mode 100644 tests/aggregatetunnelmetrics/globalscope/test_aggregate.py diff --git a/aggregatetunnelmetrics/globalscope/__init__.py b/aggregatetunnelmetrics/globalscope/__init__.py index 1467502..2bd9f3c 100644 --- a/aggregatetunnelmetrics/globalscope/__init__.py +++ b/aggregatetunnelmetrics/globalscope/__init__.py @@ -7,142 +7,10 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag # SPDX-License-Identifier: GPL-3.0-or-later -from __future__ import annotations +from .aggregate import ( + AggregateProtocolState, + AggregateState, + AggregatorConfig, +) -from dataclasses import dataclass, field -from datetime import datetime, timezone -from enum import Enum - -from .. import fieldtestingcsv - - -@dataclass -class AggregatorConfig: - """ - Configuration for the measurement aggregator. - """ - - provider: str - upstream_collector: str - probe_asn: str - probe_cc: str - - # threshold below which we emit sample_size - min_sample_size: int = 1000 - - # rounding sample_size to the nearest round_to - round_to: int = 100 - - software_name: str = "solitech-aggregate-tunnel-metrics" - software_version: str = "0.1.0" - - -def datetime_to_compact_utc(dt: datetime) -> str: - """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" - return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") - - -@dataclass -class AggregateProtocolState: - """Flat representation of the ggregated state at global scope.""" - - # Core identification - provider: str - - # Scope identification - protocol: str - - # Time window - window_start: datetime - window_end: datetime - - # Statistics about the creation phase - creation: dict[str, int] = field(default_factory=dict) - - # Statistics about the tunnel ping phase - tunnel_ping_min: list[float] = field(default_factory=list) - tunnel_ping_avg: list[float] = field(default_factory=list) - tunnel_ping_max: list[float] = field(default_factory=list) - tunnel_ping_loss: list[float] = field(default_factory=list) - - # Statistics about the NDT phase - tunnel_ndt_download_throughput: list[float] = field(default_factory=list) - tunnel_ndt_download_latency: list[float] = field(default_factory=list) - tunnel_ndt_download_rexmit: list[float] = field(default_factory=list) - tunnel_ndt_upload_throughput: list[float] = field(default_factory=list) - tunnel_ndt_upload_latency: list[float] = field(default_factory=list) - tunnel_ndt_upload_rexmit: list[float] = field(default_factory=list) - - def _update_error_counts(self, entry: fieldtestingcsv.Entry) -> None: - """Update error counts based on a new entry""" - error_type = ( - "bootstrap.generic_error" if entry.is_tunnel_error_measurement() else "" - ) - self.creation[error_type] = self.creation.get(error_type, 0) + 1 - - def _update_performance_metrics(self, entry: fieldtestingcsv.Entry) -> None: - """Update performance metrics based on a new entry""" - if not entry.is_tunnel_error_measurement(): # only successful measurements - self._update_ping(entry) - self._update_ndt(entry) - - def _update_ping(self, entry: fieldtestingcsv.Entry) -> None: - """Unconditionally update the ping metrics.""" - self.tunnel_ping_min.append(entry.ping_roundtrip_min) - self.tunnel_ping_avg.append(entry.ping_roundtrip_avg) - self.tunnel_ping_max.append(entry.ping_roundtrip_max) - self.tunnel_ping_loss.append(entry.ping_packets_loss) - - def _update_ndt(self, entry: fieldtestingcsv.Entry) -> None: - """Unconditionally update the NDT metrics.""" - self.tunnel_ndt_upload_throughput.append(entry.throughput_download) - self.tunnel_ndt_download_latency.append(entry.latency_download) - self.tunnel_ndt_download_rexmit.append(entry.retransmission_download) - self.tunnel_ndt_upload_throughput.append(entry.throughput_upload) - self.tunnel_ndt_upload_latency.append(entry.latency_upload) - self.tunnel_ndt_upload_rexmit.append(entry.retransmission_upload) - - def update(self, entry: fieldtestingcsv.Entry) -> None: - """ - Update aggregator state with a new measurement. - """ - self._update_error_counts(entry) - self._update_performance_metrics(entry) - - -@dataclass -class AggregateState: - """Aggregates measurements by protocol at global scope.""" - - config: AggregatorConfig - window_start: datetime - window_end: datetime - protocols: dict[str, AggregateProtocolState] = field(default_factory=dict) - - def _is_in_window(self, entry: fieldtestingcsv.Entry) -> bool: - """Check if entry falls within our time window""" - return self.window_start <= entry.date < self.window_end - - def _is_tunnel_entry(self, entry: fieldtestingcsv.Entry) -> bool: - """Check if entry is a tunnel measurement""" - return entry.is_tunnel_measurement() - - def update(self, entry: fieldtestingcsv.Entry) -> None: - """Update aggregator state with a new measurement.""" - # Ensure we're in window and we're looking at a tunnel entry - if not self._is_in_window(entry): - return - if not self._is_tunnel_entry(entry): - return - - # Get or create state for this protocol - if entry.protocol not in self.protocols: - self.protocols[entry.protocol] = AggregateProtocolState( - provider=self.config.provider, - protocol=entry.protocol, - window_start=self.window_start, - window_end=self.window_end, - ) - - # Update the protocol-specific state - self.protocols[entry.protocol].update(entry) +__all__ = ["AggregateProtocolState", "AggregateState", "AggregatorConfig"] diff --git a/aggregatetunnelmetrics/globalscope/aggregate.py b/aggregatetunnelmetrics/globalscope/aggregate.py new file mode 100644 index 0000000..4666e14 --- /dev/null +++ b/aggregatetunnelmetrics/globalscope/aggregate.py @@ -0,0 +1,135 @@ +""" +Internal implementation of aggregation. + +Please, import `globalscope` instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from enum import Enum + +from .. import fieldtestingcsv + + +@dataclass(frozen=True) +class AggregatorConfig: + """ + Configuration for the measurement aggregator. + """ + + provider: str + + +def datetime_to_compact_utc(dt: datetime) -> str: + """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" + return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") + + +@dataclass +class AggregateProtocolState: + """Flat representation of the ggregated state at global scope.""" + + # Core identification + provider: str + + # Scope identification + protocol: str + + # Time window + window_start: datetime + window_end: datetime + + # Statistics about the creation phase + creation: dict[str, int] = field(default_factory=dict) + + # Statistics about the tunnel ping phase + tunnel_ping_min: list[float] = field(default_factory=list) + tunnel_ping_avg: list[float] = field(default_factory=list) + tunnel_ping_max: list[float] = field(default_factory=list) + tunnel_ping_loss: list[float] = field(default_factory=list) + + # Statistics about the NDT phase + tunnel_ndt_download_throughput: list[float] = field(default_factory=list) + tunnel_ndt_download_latency: list[float] = field(default_factory=list) + tunnel_ndt_download_rexmit: list[float] = field(default_factory=list) + tunnel_ndt_upload_throughput: list[float] = field(default_factory=list) + tunnel_ndt_upload_latency: list[float] = field(default_factory=list) + tunnel_ndt_upload_rexmit: list[float] = field(default_factory=list) + + def _update_error_counts(self, entry: fieldtestingcsv.Entry) -> None: + """Update error counts based on a new entry""" + error_type = ( + "bootstrap.generic_error" if entry.is_tunnel_error_measurement() else "" + ) + self.creation[error_type] = self.creation.get(error_type, 0) + 1 + + def _update_performance_metrics(self, entry: fieldtestingcsv.Entry) -> None: + """Update performance metrics based on a new entry""" + if not entry.is_tunnel_error_measurement(): # only successful measurements + self._update_ping(entry) + self._update_ndt(entry) + + def _update_ping(self, entry: fieldtestingcsv.Entry) -> None: + """Unconditionally update the ping metrics.""" + self.tunnel_ping_min.append(entry.ping_roundtrip_min) + self.tunnel_ping_avg.append(entry.ping_roundtrip_avg) + self.tunnel_ping_max.append(entry.ping_roundtrip_max) + self.tunnel_ping_loss.append(entry.ping_packets_loss) + + def _update_ndt(self, entry: fieldtestingcsv.Entry) -> None: + """Unconditionally update the NDT metrics.""" + self.tunnel_ndt_download_throughput.append(entry.throughput_download) + self.tunnel_ndt_download_latency.append(entry.latency_download) + self.tunnel_ndt_download_rexmit.append(entry.retransmission_download) + self.tunnel_ndt_upload_throughput.append(entry.throughput_upload) + self.tunnel_ndt_upload_latency.append(entry.latency_upload) + self.tunnel_ndt_upload_rexmit.append(entry.retransmission_upload) + + def update(self, entry: fieldtestingcsv.Entry) -> None: + """ + Update aggregator state with a new measurement. + """ + self._update_error_counts(entry) + self._update_performance_metrics(entry) + + +@dataclass +class AggregateState: + """Aggregates measurements by protocol at global scope.""" + + config: AggregatorConfig + window_start: datetime + window_end: datetime + protocols: dict[str, AggregateProtocolState] = field(default_factory=dict) + + def _is_in_window(self, entry: fieldtestingcsv.Entry) -> bool: + """Check if entry falls within our time window""" + return self.window_start <= entry.date < self.window_end + + def _is_tunnel_entry(self, entry: fieldtestingcsv.Entry) -> bool: + """Check if entry is a tunnel measurement""" + return entry.is_tunnel_measurement() + + def update(self, entry: fieldtestingcsv.Entry) -> None: + """Update aggregator state with a new measurement.""" + # Ensure we're in window and we're looking at a tunnel entry + if not self._is_in_window(entry): + return + if not self._is_tunnel_entry(entry): + return + + # Get or create state for this protocol + if entry.protocol not in self.protocols: + self.protocols[entry.protocol] = AggregateProtocolState( + provider=self.config.provider, + protocol=entry.protocol, + window_start=self.window_start, + window_end=self.window_end, + ) + + # Update the protocol-specific state + self.protocols[entry.protocol].update(entry) diff --git a/tests/aggregatetunnelmetrics/globalscope/__init__.py b/tests/aggregatetunnelmetrics/globalscope/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/aggregatetunnelmetrics/globalscope/test_aggregate.py b/tests/aggregatetunnelmetrics/globalscope/test_aggregate.py new file mode 100644 index 0000000..020c085 --- /dev/null +++ b/tests/aggregatetunnelmetrics/globalscope/test_aggregate.py @@ -0,0 +1,296 @@ +"""Tests for global scope aggregation functionality.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +import unittest +from datetime import datetime, timedelta, timezone + +from aggregatetunnelmetrics.globalscope.aggregate import ( + AggregatorConfig, + AggregateState, + AggregateProtocolState, + datetime_to_compact_utc, +) +from aggregatetunnelmetrics.fieldtestingcsv import Entry + + +class TestDateTimeFormatting(unittest.TestCase): + """Tests for datetime formatting functions.""" + + def test_datetime_to_compact_utc(self): + """Test conversion of datetime to compact UTC format.""" + dt = datetime(2023, 1, 15, 14, 30, 0, tzinfo=timezone.utc) + self.assertEqual(datetime_to_compact_utc(dt), "20230115T143000Z") + + +class TestAggregateState(unittest.TestCase): + """Tests for AggregateState functionality.""" + + def setUp(self): + """Set up common test fixtures.""" + self.config = AggregatorConfig( + provider="test-provider", + ) + self.now = datetime.now(timezone.utc) + self.window_start = self.now + self.window_end = self.now + timedelta(hours=1) + + self.sample_entry = Entry( + filename="test.csv", + date=self.window_start + timedelta(minutes=30), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel", + protocol="openvpn", + throughput_download=100.0, + throughput_upload=80.0, + latency_download=50.0, + latency_upload=60.0, + retransmission_download=0.1, + retransmission_upload=0.2, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + ) + + def test_aggregate_state_initialization(self): + """Test proper initialization of AggregateState.""" + state = AggregateState( + config=self.config, + window_start=self.window_start, + window_end=self.window_end, + ) + + self.assertEqual(state.config, self.config) + self.assertEqual(state.window_start, self.window_start) + self.assertEqual(state.window_end, self.window_end) + self.assertEqual(state.protocols, {}) + + def test_aggregate_state_update_with_valid_entry(self): + """Test updating state with a valid measurement entry.""" + state = AggregateState( + config=self.config, + window_start=self.window_start, + window_end=self.window_end, + ) + + state.update(self.sample_entry) + + self.assertIn("openvpn", state.protocols) + protocol_state = state.protocols["openvpn"] + self.assertEqual(protocol_state.protocol, "openvpn") + self.assertEqual(protocol_state.provider, self.config.provider) + self.assertEqual(len(protocol_state.tunnel_ping_min), 1) + self.assertEqual(protocol_state.tunnel_ping_min[0], 10.0) + + def test_aggregate_state_ignore_out_of_window_entry(self): + """Test that out-of-window entries are ignored.""" + state = AggregateState( + config=self.config, + window_start=self.window_start, + window_end=self.window_end, + ) + + out_of_window_entry = Entry( + filename="test.csv", + date=self.window_start - timedelta(hours=1), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel", + protocol="openvpn", + throughput_download=100.0, + throughput_upload=80.0, + latency_download=50.0, + latency_upload=60.0, + retransmission_download=0.1, + retransmission_upload=0.2, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + ) + + state.update(out_of_window_entry) + self.assertEqual(len(state.protocols), 0) + + def test_aggregate_state_with_baseline_measurement(self): + """Test that baseline measurements are ignored.""" + state = AggregateState( + config=self.config, + window_start=self.window_start, + window_end=self.window_end, + ) + + baseline_entry = Entry( + filename="test.csv", + date=self.window_start + timedelta(minutes=30), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="baseline", + protocol="openvpn", + throughput_download=100.0, + throughput_upload=80.0, + latency_download=50.0, + latency_upload=60.0, + retransmission_download=0.1, + retransmission_upload=0.2, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + ) + + state.update(baseline_entry) + self.assertEqual(len(state.protocols), 0) + + def test_multiple_protocols(self): + """Test handling measurements from multiple protocols.""" + state = AggregateState( + config=self.config, + window_start=self.window_start, + window_end=self.window_end, + ) + + # Create entries for different protocols + protocols = ["openvpn", "obfs4"] + for protocol in protocols: + entry = Entry( + filename="test.csv", + date=self.window_start + timedelta(minutes=30), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel", + protocol=protocol, + throughput_download=100.0, + throughput_upload=80.0, + latency_download=50.0, + latency_upload=60.0, + retransmission_download=0.1, + retransmission_upload=0.2, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + ) + state.update(entry) + + self.assertEqual(len(state.protocols), 2) + for protocol in protocols: + self.assertIn(protocol, state.protocols) + self.assertEqual(len(state.protocols[protocol].tunnel_ping_min), 1) + + +class TestAggregateProtocolState(unittest.TestCase): + """Tests for AggregateProtocolState functionality.""" + + def setUp(self): + """Set up common test fixtures.""" + self.now = datetime.now(timezone.utc) + self.protocol_state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + self.sample_entry = Entry( + filename="test.csv", + date=self.now + timedelta(minutes=30), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel", + protocol="openvpn", + throughput_download=100.0, + throughput_upload=80.0, + latency_download=50.0, + latency_upload=60.0, + retransmission_download=0.1, + retransmission_upload=0.2, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + ) + + def test_error_counting(self): + """Test counting of error measurements.""" + error_entry = Entry( + filename="test.csv", + date=self.now + timedelta(minutes=30), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="ERROR/tunnel", + protocol="openvpn", + throughput_download=0.0, + throughput_upload=0.0, + latency_download=0.0, + latency_upload=0.0, + retransmission_download=0.0, + retransmission_upload=0.0, + ping_packets_loss=0.0, + ping_roundtrip_min=0.0, + ping_roundtrip_avg=0.0, + ping_roundtrip_max=0.0, + err_message="Connection failed", + ) + + self.protocol_state.update(error_entry) + self.assertEqual(self.protocol_state.creation["bootstrap.generic_error"], 1) + # Verify no performance metrics were recorded + self.assertEqual(len(self.protocol_state.tunnel_ping_min), 0) + + def test_performance_metrics_collection(self): + """Test collection of performance metrics.""" + self.protocol_state.update(self.sample_entry) + + self.assertEqual(len(self.protocol_state.tunnel_ping_min), 1) + self.assertEqual(len(self.protocol_state.tunnel_ping_avg), 1) + self.assertEqual(len(self.protocol_state.tunnel_ndt_download_throughput), 1) + self.assertEqual(self.protocol_state.tunnel_ping_min[0], 10.0) + self.assertEqual(self.protocol_state.tunnel_ndt_download_throughput[0], 100.0) + + +if __name__ == "__main__": + unittest.main() -- GitLab From 13854d9b03ab9c9b31ccbcaecee59d3f7140f248 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 21:33:22 +0100 Subject: [PATCH 47/75] feat: write tests for ooniformatter --- .../globalscope/__init__.py | 8 +- .../ooniformatter/__init__.py | 384 +---------- .../ooniformatter/formatter.py | 415 +++++++++++ .../ooniformatter/__init__.py | 0 .../ooniformatter/test_formatter.py | 645 ++++++++++++++++++ 5 files changed, 1078 insertions(+), 374 deletions(-) create mode 100644 aggregatetunnelmetrics/ooniformatter/formatter.py create mode 100644 tests/aggregatetunnelmetrics/ooniformatter/__init__.py create mode 100644 tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py diff --git a/aggregatetunnelmetrics/globalscope/__init__.py b/aggregatetunnelmetrics/globalscope/__init__.py index 2bd9f3c..3b56c84 100644 --- a/aggregatetunnelmetrics/globalscope/__init__.py +++ b/aggregatetunnelmetrics/globalscope/__init__.py @@ -11,6 +11,12 @@ from .aggregate import ( AggregateProtocolState, AggregateState, AggregatorConfig, + datetime_to_compact_utc, ) -__all__ = ["AggregateProtocolState", "AggregateState", "AggregatorConfig"] +__all__ = [ + "AggregateProtocolState", + "AggregateState", + "AggregatorConfig", + "datetime_to_compact_utc", +] diff --git a/aggregatetunnelmetrics/ooniformatter/__init__.py b/aggregatetunnelmetrics/ooniformatter/__init__.py index 10c0942..fd51bdb 100644 --- a/aggregatetunnelmetrics/ooniformatter/__init__.py +++ b/aggregatetunnelmetrics/ooniformatter/__init__.py @@ -6,378 +6,16 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag # SPDX-License-Identifier: GPL-3.0-or-later -from datetime import datetime -from dataclasses import dataclass -from statistics import quantiles -from typing import Any, Optional -from urllib.parse import urlunparse, urlencode - -import logging - -from ..globalscope import ( - AggregateProtocolState, - AggregatorConfig, - AggregateState, - datetime_to_compact_utc, +from .formatter import ( + AggregationTimeWindow, + TestKeys, + SerializationConfigError, + Serializer, ) -from ..oonireport import Measurement - - -@dataclass -class AggregationTimeWindow: - """Time window for aggregating measurements""" - - from_time: datetime - to_time: datetime - - def as_dict(self) -> dict: - """Convert to JSON-serializable dict""" - return { - "from": datetime_to_compact_utc(self.from_time), - "to": datetime_to_compact_utc(self.to_time), - } - - -@dataclass -class TestKeys: - """ - Models the test_keys portion of an OONI measurement as defined - in the aggregate tunnel metrics specification. - """ - - # Mandatory fields - provider: str - scope: str - protocol: str - time_window: AggregationTimeWindow - - # Optional fields depending on scope - endpoint_hostname: Optional[str] - endpoint_address: Optional[str] - endpoint_port: Optional[int] - asn: Optional[str] # Format: ^AS[0-9]+$ - cc: Optional[str] # Format: ^[A-Z]{2}$ - - # The bodies should always be present - bodies: list[dict[str, Any]] - - def as_dict(self) -> dict: - """Convert to JSON-serializable dict""" - # Start with required fields - d = { - "provider": self.provider, - "scope": self.scope, - "protocol": self.protocol, - "time_window": self.time_window.as_dict(), - "bodies": self.bodies, - } - - # Add optional fields if they exist - for field in ( - "endpoint_hostname", - "endpoint_address", - "endpoint_port", - "asn", - "cc", - ): - value = getattr(self, field) - if value is not None: - d[field] = value - - return d - - -class SerializationConfigError(Exception): - """Raised when serialization configuration does not allow us to proceed.""" - - -class Serializer: - """Converts aggregate endpoint state into OONI measurements""" - - def __init__(self, config: AggregatorConfig): - self.config = config - - @staticmethod - def _compute_percentiles(values: list[float]) -> dict[str, float]: - """Compute the required percentiles for OONI format""" - - if not values: - return {} - - q = quantiles(values, n=100, method="exclusive") - return { - "25p": round(q[24], 1), - "50p": round(q[49], 1), - "75p": round(q[74], 1), - "99p": round(q[98], 1), - } - - def _create_input_url(self, state: AggregateProtocolState) -> str: - """Create the measurement input URL""" - # The query is empty when using the global state - query = {} - - # Build URL using urlunparse for safety - return urlunparse( - ( - state.protocol, # scheme (e.g., "openvpn+obfs4") - state.provider, # netloc (e.g., "riseup.net") - "/", # path - "", # params - urlencode(query), # query (e.g., "address=1.2.3.4&...") - "", # fragment - ) - ) - - def _round_sample_size(self, sample_size: int) -> Optional[int]: - """Round the sample size according to the aggregate tunnel metrics spec.""" - if sample_size < self.config.min_sample_size: - return None - return round(sample_size / self.config.round_to) * self.config.round_to - - @staticmethod - def _maybe_with_sample_size( - obj: dict[str, Any], ss: Optional[int] - ) -> dict[str, Any]: - if ss is not None: - obj["sample_size"] = ss - return obj - - def _create_error_bodies( - self, state: AggregateProtocolState - ) -> list[dict[str, Any]]: - """Create error bodies if there are any errors""" - bodies = [] - total = sum(state.creation.values()) - if total > 0: - for error_type, count in state.creation.items(): - if not error_type: # Skip success counts - continue - bodies.append( - self._maybe_with_sample_size( - { - "phase": "creation", - "type": "network-error", - "failure_ratio": round(count / total, 2), - "error": error_type, - }, - self._round_sample_size(count), - ) - ) - return bodies - - def _validate_ping_measurements( - self, state: AggregateProtocolState, metric_type: str, measurements: list[float] - ) -> None: - """Validate ping measurements""" - if metric_type in ["min", "avg", "max"]: - for min_v, avg_v, max_v in zip( - state.tunnel_ping_min, state.tunnel_ping_avg, state.tunnel_ping_max - ): - if not (0 <= min_v <= avg_v <= max_v): - raise SerializationConfigError("invalid ping latency ordering") - if max_v > 60000: # 60 seconds - raise SerializationConfigError("unreasonable ping latency") - elif metric_type == "loss": - for loss in state.tunnel_ping_loss: - if not (0 <= loss <= 100): - raise SerializationConfigError("ping loss out of range") - - def _create_ping_bodies( - self, state: AggregateProtocolState - ) -> list[dict[str, Any]]: - """Create bodies for ping measurements""" - bodies = [] - items = ( - ("min", "latency_ms", state.tunnel_ping_min), - ("avg", "latency_ms", state.tunnel_ping_avg), - ("max", "latency_ms", state.tunnel_ping_max), - ("loss", "loss_percent", state.tunnel_ping_loss), - ) - for metric_type, key, measurements in items: - if measurements: # Only if we have measurements - try: - self._validate_ping_measurements(state, metric_type, measurements) - except SerializationConfigError as e: - logging.warning(str(e)) - continue - bodies.append( - self._maybe_with_sample_size( - { - "phase": "tunnel_ping", - "type": f"ping_{metric_type}", - "target_address": "", - key: self._compute_percentiles(measurements), - }, - self._round_sample_size(len(measurements)), - ) - ) - return bodies - - def _validate_ndt_measurements( - self, - throughput: list[float], - latency: list[float], - rexmit: list[float], - phase: str, - ) -> None: - """Validate NDT measurements""" - if len({len(throughput), len(latency), len(rexmit)}) > 1: - raise SerializationConfigError( - f"inconsistent NDT {phase} measurement counts" - ) - for t, l, r in zip(throughput, latency, rexmit): - if t < 0 or t > 10000: # 10 Gbps - raise SerializationConfigError(f"unreasonable NDT {phase} throughput") - if l < 0 or l > 60000: # 60 seconds - raise SerializationConfigError(f"unreasonable NDT {phase} latency") - if not (0 <= r <= 100): - raise SerializationConfigError( - f"NDT {phase} retransmission out of range" - ) - - def _create_ndt_bodies(self, state: AggregateProtocolState) -> list[dict[str, Any]]: - """Create bodies for NDT measurements""" - bodies = [] - items = ( - ( - "download", - state.tunnel_ndt_download_throughput, - state.tunnel_ndt_download_latency, - state.tunnel_ndt_download_rexmit, - ), - ( - "upload", - state.tunnel_ndt_upload_throughput, - state.tunnel_ndt_upload_latency, - state.tunnel_ndt_upload_rexmit, - ), - ) - for phase, throughput, latency, rexmit in items: - try: - self._validate_ndt_measurements(throughput, latency, rexmit, phase) - except SerializationConfigError as e: - logging.warning(str(e)) - continue - bodies.append( - self._maybe_with_sample_size( - { - "phase": f"tunnel_ndt_{phase}", - "type": f"ndt_{phase}", - "target_hostname": "", - "target_address": "", - "target_port": 0, - "latency_ms": self._compute_percentiles(latency), - "speed_mbits": self._compute_percentiles(throughput), - "retransmission_percent": self._compute_percentiles(rexmit), - }, - self._round_sample_size(len(throughput)), - ) - ) - return bodies - - def _create_global_bodies( - self, state: AggregateProtocolState - ) -> list[dict[str, Any]]: - """Create the bodies section of test_keys""" - bodies = [] - bodies.extend(self._create_error_bodies(state)) - bodies.extend(self._create_ping_bodies(state)) - bodies.extend(self._create_ndt_bodies(state)) - return bodies - - def _is_valid_state(self, state: AggregateProtocolState) -> bool: - """ - Validates the state before serialization. Returns False if state - should be skipped, True if it's valid to serialize. - - Logs warning messages explaining validation failures. - """ - - # Basic field validations - if not state.provider or not isinstance(state.provider, str): - logging.warning("invalid provider field") - return False - if not state.protocol or not isinstance(state.protocol, str): - logging.warning("invalid protocol field") - return False - if not state.window_start or not state.window_end: - logging.warning("invalid time window") - return False - if state.window_end <= state.window_start: - logging.warning("end time before start time") - return False - - # Creation phase validations - if not state.creation: - logging.warning("no creation phase data") - return False - if any(count < 0 for count in state.creation.values()): - logging.warning("negative creation counts") - return False - if sum(state.creation.values()) == 0: - logging.warning("no measurements") - return False - - # Logical consistency validations - success_count = state.creation.get("", 0) - has_measurements = bool( - state.tunnel_ping_min or state.tunnel_ndt_download_throughput - ) - if success_count > 0 and not has_measurements: - logging.warning("successful creations but no measurements") - return False - if success_count == 0 and has_measurements: - logging.warning("measurements without successful creations") - return False - - return True - - def serialize_global(self, state: AggregateState) -> list[Measurement]: - """ - Convert global state to OONI measurement format. - - Raises: - SerializationError: if the scope is not model.Scope.ENDPOINT. - """ - measurement_time = datetime.utcnow() - measurements = [] - - for proto_name, proto_state in state.protocols.items(): - if not self._is_valid_state(proto_state): - logging.warning(f"skipping invalid state for protocol {proto_name}") - continue - - test_keys = TestKeys( - provider=state.config.provider, - scope="global", - protocol=proto_name, - time_window=AggregationTimeWindow( - from_time=state.window_start, to_time=state.window_end - ), - endpoint_hostname=None, - endpoint_address=None, - endpoint_port=None, - asn=None, - cc=None, - bodies=self._create_global_bodies(proto_state), - ) - - mx = Measurement( - annotations={"upstream_collector": self.config.upstream_collector}, - data_format_version="0.2.0", - input=self._create_input_url(proto_state), - measurement_start_time=measurement_time, - probe_asn=self.config.probe_asn, - probe_cc=self.config.probe_cc, - software_name=self.config.software_name, - software_version=self.config.software_version, - test_keys=test_keys, - test_name="aggregate_tunnel_metrics", - test_runtime=0.0, - test_start_time=measurement_time, - test_version="0.1.0", - ) - measurements.append(mx) - return measurements +__all__ = [ + "AggregationTimeWindow", + "SerializationConfigError", + "Serializer", + "TestKeys", +] diff --git a/aggregatetunnelmetrics/ooniformatter/formatter.py b/aggregatetunnelmetrics/ooniformatter/formatter.py new file mode 100644 index 0000000..146035e --- /dev/null +++ b/aggregatetunnelmetrics/ooniformatter/formatter.py @@ -0,0 +1,415 @@ +""" +Internal OONI formatter implementation. + +Please, prefer to import `ooniformatter` instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone +from dataclasses import dataclass +from statistics import quantiles +from typing import Any, Optional +from urllib.parse import urlunparse, urlencode + +import logging + +from ..globalscope import ( + AggregateProtocolState, + AggregatorConfig, + AggregateState, + datetime_to_compact_utc, +) +from ..oonireport import Measurement + + +@dataclass(frozen=True) +class FormatterConfig: + """ + Configuration for the OONI measurement formatter. + """ + + upstream_collector: str + probe_asn: str + probe_cc: str + min_sample_size: int = 1000 + round_to: int = 100 + software_name: str = "solitech-aggregate-tunnel-metrics" + software_version: str = "0.1.0" + + +@dataclass(frozen=True) +class AggregationTimeWindow: + """Time window for aggregating measurements""" + + from_time: datetime + to_time: datetime + + def as_dict(self) -> dict: + """Convert to JSON-serializable dict""" + return { + "from": datetime_to_compact_utc(self.from_time), + "to": datetime_to_compact_utc(self.to_time), + } + + +@dataclass +class TestKeys: + """ + Models the test_keys portion of an OONI measurement as defined + in the aggregate tunnel metrics specification. + """ + + # Mandatory fields + provider: str + scope: str + protocol: str + time_window: AggregationTimeWindow + + # Optional fields depending on scope + endpoint_hostname: Optional[str] + endpoint_address: Optional[str] + endpoint_port: Optional[int] + asn: Optional[str] # Format: ^AS[0-9]+$ + cc: Optional[str] # Format: ^[A-Z]{2}$ + + # The bodies should always be present + bodies: list[dict[str, Any]] + + def as_dict(self) -> dict: + """Convert to JSON-serializable dict""" + # Start with required fields + d = { + "provider": self.provider, + "scope": self.scope, + "protocol": self.protocol, + "time_window": self.time_window.as_dict(), + "bodies": self.bodies, + } + + # Add optional fields if they exist + for field in ( + "endpoint_hostname", + "endpoint_address", + "endpoint_port", + "asn", + "cc", + ): + value = getattr(self, field) + if value is not None: + d[field] = value + + return d + + +class SerializationConfigError(Exception): + """Raised when serialization configuration does not allow us to proceed.""" + + +class Serializer: + """Converts aggregate endpoint state into OONI measurements""" + + def __init__(self, ac: AggregatorConfig, fc: FormatterConfig): + self.aggregator_cfg = ac + self.formatter_cfg = fc + + @staticmethod + def _compute_percentiles(values: list[float]) -> dict[str, float]: + """Compute the required percentiles for OONI format""" + + # No values, no percentiles, no future 🎶 + if not values: + return {} + + # If we have just one value, use it for all percentiles + if len(values) == 1: + value = round(values[0], 1) + return { + "25p": value, + "50p": value, + "75p": value, + "99p": value, + } + + q = quantiles(values, n=100, method="exclusive") + return { + "25p": round(q[24], 1), + "50p": round(q[49], 1), + "75p": round(q[74], 1), + "99p": round(q[98], 1), + } + + def _create_input_url(self, state: AggregateProtocolState) -> str: + """Create the measurement input URL""" + # The query is empty when using the global state + query = {} + + # Build URL using urlunparse for safety + return urlunparse( + ( + state.protocol, # scheme (e.g., "openvpn+obfs4") + state.provider, # netloc (e.g., "riseup.net") + "/", # path + "", # params + urlencode(query), # query (e.g., "address=1.2.3.4&...") + "", # fragment + ) + ) + + def _round_sample_size(self, sample_size: int) -> Optional[int]: + """Round the sample size according to the aggregate tunnel metrics spec.""" + if sample_size < self.formatter_cfg.min_sample_size: + return None + return ( + round(sample_size / self.formatter_cfg.round_to) + * self.formatter_cfg.round_to + ) + + @staticmethod + def _maybe_with_sample_size( + obj: dict[str, Any], ss: Optional[int] + ) -> dict[str, Any]: + if ss is not None: + obj["sample_size"] = ss + return obj + + def _create_error_bodies( + self, state: AggregateProtocolState + ) -> list[dict[str, Any]]: + """Create error bodies if there are any errors""" + bodies = [] + total = sum(state.creation.values()) + if total > 0: + for error_type, count in state.creation.items(): + if not error_type: # Skip success counts + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": "creation", + "type": "network-error", + "failure_ratio": round(count / total, 2), + "error": error_type, + }, + self._round_sample_size(count), + ) + ) + return bodies + + def _validate_ping_measurements( + self, state: AggregateProtocolState, metric_type: str, measurements: list[float] + ) -> None: + """Validate ping measurements""" + if metric_type in ["min", "avg", "max"]: + for min_v, avg_v, max_v in zip( + state.tunnel_ping_min, state.tunnel_ping_avg, state.tunnel_ping_max + ): + if not (0 <= min_v <= avg_v <= max_v): + raise SerializationConfigError("invalid ping latency ordering") + if max_v > 60000: # 60 seconds + raise SerializationConfigError("unreasonable ping latency") + elif metric_type == "loss": + for loss in state.tunnel_ping_loss: + if not (0 <= loss <= 100): + raise SerializationConfigError("ping loss out of range") + + def _create_ping_bodies( + self, state: AggregateProtocolState + ) -> list[dict[str, Any]]: + """Create bodies for ping measurements""" + bodies = [] + items = ( + ("min", "latency_ms", state.tunnel_ping_min), + ("avg", "latency_ms", state.tunnel_ping_avg), + ("max", "latency_ms", state.tunnel_ping_max), + ("loss", "loss_percent", state.tunnel_ping_loss), + ) + for metric_type, key, measurements in items: + if measurements: # Only if we have measurements + try: + self._validate_ping_measurements(state, metric_type, measurements) + except SerializationConfigError as e: + logging.warning(str(e)) + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": "tunnel_ping", + "type": f"ping_{metric_type}", + "target_address": "", + key: self._compute_percentiles(measurements), + }, + self._round_sample_size(len(measurements)), + ) + ) + return bodies + + def _validate_ndt_measurements( + self, + throughput: list[float], + latency: list[float], + rexmit: list[float], + phase: str, + ) -> None: + """Validate NDT measurements""" + if len({len(throughput), len(latency), len(rexmit)}) > 1: + raise SerializationConfigError( + f"inconsistent NDT {phase} measurement counts" + ) + for t, l, r in zip(throughput, latency, rexmit): + if t < 0 or t > 10000: # 10 Gbps + raise SerializationConfigError(f"unreasonable NDT {phase} throughput") + if l < 0 or l > 60000: # 60 seconds + raise SerializationConfigError(f"unreasonable NDT {phase} latency") + if not (0 <= r <= 100): + raise SerializationConfigError( + f"NDT {phase} retransmission out of range" + ) + + def _create_ndt_bodies(self, state: AggregateProtocolState) -> list[dict[str, Any]]: + """Create bodies for NDT measurements""" + bodies = [] + items = ( + ( + "download", + state.tunnel_ndt_download_throughput, + state.tunnel_ndt_download_latency, + state.tunnel_ndt_download_rexmit, + ), + ( + "upload", + state.tunnel_ndt_upload_throughput, + state.tunnel_ndt_upload_latency, + state.tunnel_ndt_upload_rexmit, + ), + ) + for phase, throughput, latency, rexmit in items: + try: + self._validate_ndt_measurements(throughput, latency, rexmit, phase) + except SerializationConfigError as e: + logging.warning(str(e)) + continue + bodies.append( + self._maybe_with_sample_size( + { + "phase": f"tunnel_ndt_{phase}", + "type": f"ndt_{phase}", + "target_hostname": "", + "target_address": "", + "target_port": 0, + "latency_ms": self._compute_percentiles(latency), + "speed_mbits": self._compute_percentiles(throughput), + "retransmission_percent": self._compute_percentiles(rexmit), + }, + self._round_sample_size(len(throughput)), + ) + ) + return bodies + + def _create_global_bodies( + self, state: AggregateProtocolState + ) -> list[dict[str, Any]]: + """Create the bodies section of test_keys""" + bodies = [] + bodies.extend(self._create_error_bodies(state)) + bodies.extend(self._create_ping_bodies(state)) + bodies.extend(self._create_ndt_bodies(state)) + return bodies + + def _is_valid_state(self, state: AggregateProtocolState) -> bool: + """ + Validates the state before serialization. Returns False if state + should be skipped, True if it's valid to serialize. + + Logs warning messages explaining validation failures. + """ + + # TODO(bassosimone): conventional wisdom says "parse, don't validate" + # but I am not sure how I could apply this principle here + + # Basic field validations + if not state.provider: + logging.warning("invalid provider field") + return False + if not state.protocol: + logging.warning("invalid protocol field") + return False + if state.window_end <= state.window_start: + logging.warning("end time before start time") + return False + + # Creation phase validations + if not state.creation: + logging.warning("no creation phase data") + return False + if any(count < 0 for count in state.creation.values()): + logging.warning("negative creation counts") + return False + if sum(state.creation.values()) == 0: + logging.warning("no measurements") + return False + + # Logical consistency validations + success_count = state.creation.get("", 0) + has_measurements = bool( + state.tunnel_ping_min or state.tunnel_ndt_download_throughput + ) + if success_count > 0 and not has_measurements: + logging.warning("successful creations but no measurements") + return False + if success_count == 0 and has_measurements: + logging.warning("measurements without successful creations") + return False + + return True + + def serialize_global(self, state: AggregateState) -> list[Measurement]: + """ + Convert global state to OONI measurement format. + + Raises: + SerializationError: if the scope is not model.Scope.ENDPOINT. + """ + measurement_time = datetime.now(timezone.utc) + measurements = [] + + for proto_name, proto_state in state.protocols.items(): + if not self._is_valid_state(proto_state): + logging.warning(f"skipping invalid state for protocol {proto_name}") + continue + + test_keys = TestKeys( + provider=state.config.provider, + scope="global", + protocol=proto_name, + time_window=AggregationTimeWindow( + from_time=state.window_start, to_time=state.window_end + ), + endpoint_hostname=None, + endpoint_address=None, + endpoint_port=None, + asn=None, + cc=None, + bodies=self._create_global_bodies(proto_state), + ) + + mx = Measurement( + annotations={ + "upstream_collector": self.formatter_cfg.upstream_collector + }, + data_format_version="0.2.0", + input=self._create_input_url(proto_state), + measurement_start_time=measurement_time, + probe_asn=self.formatter_cfg.probe_asn, + probe_cc=self.formatter_cfg.probe_cc, + software_name=self.formatter_cfg.software_name, + software_version=self.formatter_cfg.software_version, + test_keys=test_keys, + test_name="aggregate_tunnel_metrics", + test_runtime=0.0, + test_start_time=measurement_time, + test_version="0.1.0", + ) + measurements.append(mx) + + return measurements diff --git a/tests/aggregatetunnelmetrics/ooniformatter/__init__.py b/tests/aggregatetunnelmetrics/ooniformatter/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py b/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py new file mode 100644 index 0000000..eb1cbe5 --- /dev/null +++ b/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py @@ -0,0 +1,645 @@ +"""Tests for OONI formatter functionality.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timedelta, timezone + +from typing import cast +import unittest + +from aggregatetunnelmetrics.globalscope import ( + AggregatorConfig, + AggregateState, + AggregateProtocolState, +) +from aggregatetunnelmetrics.ooniformatter.formatter import ( + FormatterConfig, + Serializer, + AggregationTimeWindow, + TestKeys, + SerializationConfigError, +) + + +class TestAggregationTimeWindow(unittest.TestCase): + """Tests for AggregationTimeWindow functionality.""" + + def test_as_dict(self): + """Test time window serialization to dict.""" + start = datetime(2023, 1, 15, 14, 30, 0, tzinfo=timezone.utc) + end = start + timedelta(hours=1) + window = AggregationTimeWindow(from_time=start, to_time=end) + + expected = { + "from": "20230115T143000Z", + "to": "20230115T153000Z", + } + self.assertEqual(window.as_dict(), expected) + + +class TestSerializer(unittest.TestCase): + """Tests for measurement serialization.""" + + def setUp(self): + """Set up common test fixtures.""" + self.aggregator_config = AggregatorConfig(provider="test-provider") + self.formatter_config = FormatterConfig( + upstream_collector="test-collector", + probe_asn="AS12345", + probe_cc="XX", + min_sample_size=10, # Small value for testing + round_to=5, # Small value for testing + ) + self.serializer = Serializer(self.aggregator_config, self.formatter_config) + + # Set up a basic state for testing + self.now = datetime.now(timezone.utc) + self.state = AggregateState( + config=self.aggregator_config, + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + def test_compute_percentiles(self): + """Test percentile computation.""" + values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0] + result = self.serializer._compute_percentiles(values) + + self.assertEqual(len(result), 4) # Should have all 4 percentiles + self.assertIn("25p", result) + self.assertIn("50p", result) + self.assertIn("75p", result) + self.assertIn("99p", result) + + def test_round_sample_size(self): + """Test sample size rounding logic.""" + # Below min_sample_size + self.assertIsNone(self.serializer._round_sample_size(5)) + + # Above min_sample_size + self.assertEqual( + self.serializer._round_sample_size(12), 10 + ) # Rounds to nearest 5 + self.assertEqual( + self.serializer._round_sample_size(13), 15 + ) # Rounds to nearest 5 + + def test_create_input_url(self): + """Test input URL creation.""" + proto_state = AggregateProtocolState( + provider="riseup.net", + protocol="openvpn+obfs4", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + url = self.serializer._create_input_url(proto_state) + self.assertEqual(url, "openvpn+obfs4://riseup.net/") + + def test_validate_ping_measurements(self): + """Test ping measurement validation.""" + proto_state = AggregateProtocolState( + provider="test", + protocol="test", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Test invalid ordering + proto_state.tunnel_ping_min = [10.0] + proto_state.tunnel_ping_avg = [9.0] # Less than min + proto_state.tunnel_ping_max = [11.0] + + with self.assertRaises(SerializationConfigError): + self.serializer._validate_ping_measurements( + proto_state, "min", proto_state.tunnel_ping_min + ) + + # Test unreasonable values + proto_state.tunnel_ping_max = [70000.0] # > 60 seconds + with self.assertRaises(SerializationConfigError): + self.serializer._validate_ping_measurements( + proto_state, "max", proto_state.tunnel_ping_max + ) + + def test_serialize_global_valid_state(self): + """Test serialization of valid global state.""" + proto_state = AggregateProtocolState( + provider=self.aggregator_config.provider, + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Add some valid measurements + proto_state.creation[""] = 100 # Success count + proto_state.tunnel_ping_min = [10.0, 11.0] + proto_state.tunnel_ping_avg = [12.0, 13.0] + proto_state.tunnel_ping_max = [14.0, 15.0] + proto_state.tunnel_ping_loss = [0.0, 1.0] + + self.state.protocols["openvpn"] = proto_state + + measurements = self.serializer.serialize_global(self.state) + + self.assertEqual(len(measurements), 1) + m = measurements[0] + self.assertEqual(m.probe_asn, "AS12345") + self.assertEqual(m.probe_cc, "XX") + self.assertEqual(m.test_name, "aggregate_tunnel_metrics") + tk = cast(TestKeys, m.test_keys) + self.assertEqual(tk.provider, "test-provider") + self.assertEqual(tk.protocol, "openvpn") + self.assertEqual(tk.scope, "global") + + def test_serialize_invalid_state(self): + """Test serialization with invalid state is skipped.""" + proto_state = AggregateProtocolState( + provider="", # Invalid empty provider + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + self.state.protocols["openvpn"] = proto_state + + measurements = self.serializer.serialize_global(self.state) + self.assertEqual(len(measurements), 0) + + def test_is_valid_state_invalid_provider(self): + """Test validation of invalid provider field.""" + state = AggregateProtocolState( + provider="", # Empty provider + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + self.assertFalse(self.serializer._is_valid_state(state)) + + def test_is_valid_state_invalid_protocol(self): + """Test validation of invalid protocol field.""" + state = AggregateProtocolState( + provider="test-provider", + protocol="", # Empty protocol + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + self.assertFalse(self.serializer._is_valid_state(state)) + + def test_is_valid_state_invalid_time_window(self): + """Test validation of invalid time window.""" + # End time before start time + state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now - timedelta(hours=1), + window_end=self.now, + ) + self.assertFalse(self.serializer._is_valid_state(state)) + + def test_is_valid_state_creation_phase(self): + """Test validation of creation phase data.""" + state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now - timedelta(hours=1), + ) + + # No creation data + self.assertFalse(self.serializer._is_valid_state(state)) + + # Negative creation counts + state.creation = {"": -1} + self.assertFalse(self.serializer._is_valid_state(state)) + + # Zero total measurements + state.creation = {} + self.assertFalse(self.serializer._is_valid_state(state)) + + def test_is_valid_state_logical_consistency(self): + """Test validation of logical consistency between success counts and measurements.""" + state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Negative creation values + state.creation = {"": 10, "error": -1} + self.assertFalse(self.serializer._is_valid_state(state)) + + # Sum of creation values is zero + state.creation = {"": 0, "error": 0} + self.assertFalse(self.serializer._is_valid_state(state)) + + # Successful creations but no measurements + state.creation = {"": 10} # 10 successful creations + self.assertFalse(self.serializer._is_valid_state(state)) + + # Measurements without successful creations + state.creation = {"error": 1} # Only error, no successes + state.tunnel_ping_min = [10.0] # But we have measurements + self.assertFalse(self.serializer._is_valid_state(state)) + + def test_is_valid_state_valid_case(self): + """Test validation with valid state.""" + state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Valid creation counts and measurements + state.creation = {"": 10, "error": 1} + state.tunnel_ping_min = [10.0] + state.tunnel_ping_avg = [12.0] + state.tunnel_ping_max = [15.0] + + self.assertTrue(self.serializer._is_valid_state(state)) + + def test_create_ndt_bodies_validation_failures(self): + """Test NDT bodies creation with invalid measurements.""" + proto_state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Test case 1: Inconsistent measurement counts + proto_state.tunnel_ndt_download_throughput = [1.0, 2.0] + proto_state.tunnel_ndt_download_latency = [10.0] # Different length + proto_state.tunnel_ndt_download_rexmit = [0.5] + proto_state.tunnel_ndt_upload_throughput = [1.0] + proto_state.tunnel_ndt_upload_latency = [10.0] + proto_state.tunnel_ndt_upload_rexmit = [0.5] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should still succeed + + # Test case 2: Unreasonable throughput + proto_state.tunnel_ndt_download_throughput = [15000.0] # > 10 Gbps + proto_state.tunnel_ndt_download_latency = [10.0] + proto_state.tunnel_ndt_download_rexmit = [0.5] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should still succeed + + # Test case 3: Test both phases failing + # Invalid download phase + proto_state.tunnel_ndt_download_throughput = [15000.0] # > 10 Gbps + proto_state.tunnel_ndt_download_latency = [10.0] + proto_state.tunnel_ndt_download_rexmit = [0.5] + # Invalid upload phase + proto_state.tunnel_ndt_upload_throughput = [15000.0] # > 10 Gbps + proto_state.tunnel_ndt_upload_latency = [10.0] + proto_state.tunnel_ndt_upload_rexmit = [0.5] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 0) # Both phases should fail + + # Test case 4: Negative latency + proto_state.tunnel_ndt_download_throughput = [100.0] + proto_state.tunnel_ndt_download_latency = [-1.0] # Invalid negative latency + proto_state.tunnel_ndt_download_rexmit = [1.0] + proto_state.tunnel_ndt_upload_throughput = [100.0] + proto_state.tunnel_ndt_upload_latency = [100.0] + proto_state.tunnel_ndt_upload_rexmit = [1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should succeed + + # Test case 5: Latency > 60 seconds + proto_state.tunnel_ndt_download_throughput = [100.0] + proto_state.tunnel_ndt_download_latency = [65000.0] # > 60 seconds + proto_state.tunnel_ndt_download_rexmit = [1.0] + proto_state.tunnel_ndt_upload_throughput = [100.0] + proto_state.tunnel_ndt_upload_latency = [100.0] + proto_state.tunnel_ndt_upload_rexmit = [1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should succeed + + # Test case 6: Negative retransmission + proto_state.tunnel_ndt_download_throughput = [100.0] + proto_state.tunnel_ndt_download_latency = [100.0] + proto_state.tunnel_ndt_download_rexmit = [ + -1.0 + ] # Invalid negative retransmission + proto_state.tunnel_ndt_upload_throughput = [100.0] + proto_state.tunnel_ndt_upload_latency = [100.0] + proto_state.tunnel_ndt_upload_rexmit = [1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should succeed + + # Test case 7: Retransmission > 100% + proto_state.tunnel_ndt_download_throughput = [100.0] + proto_state.tunnel_ndt_download_latency = [100.0] + proto_state.tunnel_ndt_download_rexmit = [150.0] # > 100% + proto_state.tunnel_ndt_upload_throughput = [100.0] + proto_state.tunnel_ndt_upload_latency = [100.0] + proto_state.tunnel_ndt_upload_rexmit = [1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should succeed + + # Test case 8: Multiple validation failures in same phase + proto_state.tunnel_ndt_download_throughput = [100.0] + proto_state.tunnel_ndt_download_latency = [-1.0] # Invalid negative latency + proto_state.tunnel_ndt_download_rexmit = [150.0] # > 100% + proto_state.tunnel_ndt_upload_throughput = [100.0] + proto_state.tunnel_ndt_upload_latency = [100.0] + proto_state.tunnel_ndt_upload_rexmit = [1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Upload phase should succeed + + # Test case 9: All invalid measurements in both phases + proto_state.tunnel_ndt_download_throughput = [100.0] + proto_state.tunnel_ndt_download_latency = [-1.0] + proto_state.tunnel_ndt_download_rexmit = [150.0] + proto_state.tunnel_ndt_upload_throughput = [100.0] + proto_state.tunnel_ndt_upload_latency = [65000.0] + proto_state.tunnel_ndt_upload_rexmit = [-1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ndt_bodies(proto_state) + self.assertEqual(len(bodies), 0) # Both phases should fail + + def test_create_ping_bodies_validation_failures(self): + """Test ping bodies creation with invalid measurements.""" + proto_state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Test case 1: Invalid ordering (min > avg) + proto_state.tunnel_ping_min = [20.0] + proto_state.tunnel_ping_avg = [10.0] + proto_state.tunnel_ping_max = [30.0] + proto_state.tunnel_ping_loss = [1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Only loss measurement should succeed + + # Test case 2: Invalid ordering (avg > max) + proto_state.tunnel_ping_min = [10.0] + proto_state.tunnel_ping_avg = [30.0] + proto_state.tunnel_ping_max = [20.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Only loss measurement should succeed + + # Test case 3: Unreasonable latency (> 60 seconds) + proto_state.tunnel_ping_min = [1000.0] + proto_state.tunnel_ping_avg = [61000.0] # > 60 seconds + proto_state.tunnel_ping_max = [65000.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Only loss measurement should succeed + + # Test case 4: Negative latency + proto_state.tunnel_ping_min = [-1.0] + proto_state.tunnel_ping_avg = [10.0] + proto_state.tunnel_ping_max = [20.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 1) # Only loss measurement should succeed + + # Test case 5: Invalid loss percentage (negative) + proto_state.tunnel_ping_min = [10.0] + proto_state.tunnel_ping_avg = [20.0] + proto_state.tunnel_ping_max = [30.0] + proto_state.tunnel_ping_loss = [-1.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 3) # Only latency measurements should succeed + + # Test case 6: Invalid loss percentage (> 100%) + proto_state.tunnel_ping_loss = [150.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 3) # Only latency measurements should succeed + + # Test case 7: Empty measurements + proto_state.tunnel_ping_min = [] + proto_state.tunnel_ping_avg = [] + proto_state.tunnel_ping_max = [] + proto_state.tunnel_ping_loss = [] + + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 0) # No measurements should produce no bodies + + # Test case 8: All measurements invalid + proto_state.tunnel_ping_min = [-1.0] + proto_state.tunnel_ping_avg = [61000.0] + proto_state.tunnel_ping_max = [65000.0] + proto_state.tunnel_ping_loss = [150.0] + + with self.assertLogs(level="WARNING"): + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 0) # All measurements should fail + + # Test case 9: Valid measurements + proto_state.tunnel_ping_min = [10.0] + proto_state.tunnel_ping_avg = [20.0] + proto_state.tunnel_ping_max = [30.0] + proto_state.tunnel_ping_loss = [50.0] + + bodies = self.serializer._create_ping_bodies(proto_state) + self.assertEqual(len(bodies), 4) # All measurements should succeed + + def test_create_error_bodies_with_sample_sizes(self): + """Test error bodies creation with sample size handling.""" + proto_state = AggregateProtocolState( + provider="test-provider", + protocol="openvpn", + window_start=self.now, + window_end=self.now + timedelta(hours=1), + ) + + # Test case 1: No errors + proto_state.creation = {"": 100} # Only successes + bodies = self.serializer._create_error_bodies(proto_state) + self.assertEqual(len(bodies), 0) # No error bodies should be created + + # Test case 2: Errors below minimum sample size + proto_state.creation = { + "": 100, # Successes + "connection_error": 5, # Below min_sample_size (10) + } + bodies = self.serializer._create_error_bodies(proto_state) + self.assertEqual(len(bodies), 1) + self.assertNotIn("sample_size", bodies[0]) # Should not include sample size + self.assertEqual(bodies[0]["failure_ratio"], 0.05) # 5/100 + self.assertEqual(bodies[0]["error"], "connection_error") + + # Test case 3: Errors above minimum sample size + proto_state.creation = { + "": 1000, # Successes + "timeout_error": 200, # Above min_sample_size (10) + "connection_error": 300, # Above min_sample_size (10) + } + bodies = self.serializer._create_error_bodies(proto_state) + self.assertEqual(len(bodies), 2) + + # Check both error bodies + for body in bodies: + self.assertIn("sample_size", body) # Should include sample size + if body["error"] == "timeout_error": + self.assertEqual(body["sample_size"], 200) + self.assertEqual(body["failure_ratio"], 0.13) # 200/1500 + else: + self.assertEqual(body["sample_size"], 300) + self.assertEqual(body["failure_ratio"], 0.20) # 300/1500 + + # Test case 4: Mixed error counts + proto_state.creation = { + "": 1000, # Successes + "timeout_error": 5, # Below min_sample_size + "connection_error": 200, # Above min_sample_size + } + bodies = self.serializer._create_error_bodies(proto_state) + self.assertEqual(len(bodies), 2) + + # Check both error bodies + for body in bodies: + if body["error"] == "timeout_error": + self.assertNotIn("sample_size", body) + self.assertEqual(body["failure_ratio"], round(5 / 1205, 2)) + else: + self.assertIn("sample_size", body) + self.assertEqual(body["sample_size"], 200) + self.assertEqual(body["failure_ratio"], round(200 / 1205, 2)) + + +class TestTestKeys(unittest.TestCase): + """Tests for TestKeys functionality.""" + + def setUp(self): + """Set up common test fixtures.""" + self.window = AggregationTimeWindow( + from_time=datetime.now(timezone.utc), + to_time=datetime.now(timezone.utc) + timedelta(hours=1), + ) + + def test_as_dict_mandatory_fields(self): + """Test TestKeys serialization with mandatory fields.""" + window = AggregationTimeWindow( + from_time=datetime.now(timezone.utc), + to_time=datetime.now(timezone.utc) + timedelta(hours=1), + ) + + test_keys = TestKeys( + provider="test-provider", + scope="global", + protocol="openvpn", + time_window=window, + endpoint_hostname=None, + endpoint_address=None, + endpoint_port=None, + asn=None, + cc=None, + bodies=[], + ) + + result = test_keys.as_dict() + + self.assertEqual(result["provider"], "test-provider") + self.assertEqual(result["scope"], "global") + self.assertEqual(result["protocol"], "openvpn") + self.assertIn("time_window", result) + self.assertIn("bodies", result) + + # Optional fields should not be present + self.assertNotIn("endpoint_hostname", result) + self.assertNotIn("endpoint_address", result) + self.assertNotIn("endpoint_port", result) + self.assertNotIn("asn", result) + self.assertNotIn("cc", result) + + def test_as_dict_with_optional_fields(self): + """Test TestKeys serialization with optional fields present.""" + test_keys = TestKeys( + provider="test-provider", + scope="endpoint", # Changed to endpoint since we're including endpoint details + protocol="openvpn", + time_window=self.window, + endpoint_hostname="test.example.com", + endpoint_address="1.2.3.4", + endpoint_port=1234, + asn="AS12345", + cc="XX", + bodies=[], + ) + + result = test_keys.as_dict() + + # Check mandatory fields + self.assertEqual(result["provider"], "test-provider") + self.assertEqual(result["scope"], "endpoint") + self.assertEqual(result["protocol"], "openvpn") + self.assertIn("time_window", result) + self.assertIn("bodies", result) + + # Check optional fields are present with correct values + self.assertEqual(result["endpoint_hostname"], "test.example.com") + self.assertEqual(result["endpoint_address"], "1.2.3.4") + self.assertEqual(result["endpoint_port"], 1234) + self.assertEqual(result["asn"], "AS12345") + self.assertEqual(result["cc"], "XX") + + def test_as_dict_with_partial_optional_fields(self): + """Test TestKeys serialization with some optional fields present.""" + test_keys = TestKeys( + provider="test-provider", + scope="endpoint", + protocol="openvpn", + time_window=self.window, + endpoint_hostname="test.example.com", + endpoint_address="1.2.3.4", + endpoint_port=None, # Only some optional fields set + asn="AS12345", + cc=None, + bodies=[], + ) + + result = test_keys.as_dict() + + # Check mandatory fields + self.assertEqual(result["provider"], "test-provider") + self.assertEqual(result["scope"], "endpoint") + self.assertEqual(result["protocol"], "openvpn") + + # Check present optional fields + self.assertEqual(result["endpoint_hostname"], "test.example.com") + self.assertEqual(result["endpoint_address"], "1.2.3.4") + self.assertEqual(result["asn"], "AS12345") + + # Check omitted optional fields + self.assertNotIn("endpoint_port", result) + self.assertNotIn("cc", result) + + +if __name__ == "__main__": + unittest.main() -- GitLab From f48e0ff7bb261a667f93f0904017ff7a4b37b32d Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 15 Feb 2025 22:09:18 +0100 Subject: [PATCH 48/75] feat: write the top-level pipeline package I let the AI (Claude 3.5) rewrite this based on previous implementation and what was obvious given the existing APIs. It remains to write unit and integration testing for this and to update the design document and move it. --- .../ooniformatter/__init__.py | 2 + .../ooniformatter/formatter.py | 42 ++--- aggregatetunnelmetrics/pipeline/__init__.py | 26 +++ aggregatetunnelmetrics/pipeline/config.py | 70 ++++++++ aggregatetunnelmetrics/pipeline/errors.py | 15 ++ aggregatetunnelmetrics/pipeline/processor.py | 159 ++++++++++++++++++ aggregatetunnelmetrics/pipeline/state.py | 59 +++++++ aggregatetunnelmetrics/pipeline/window.py | 67 ++++++++ .../ooniformatter/test_formatter.py | 4 +- 9 files changed, 422 insertions(+), 22 deletions(-) create mode 100644 aggregatetunnelmetrics/pipeline/__init__.py create mode 100644 aggregatetunnelmetrics/pipeline/config.py create mode 100644 aggregatetunnelmetrics/pipeline/errors.py create mode 100644 aggregatetunnelmetrics/pipeline/processor.py create mode 100644 aggregatetunnelmetrics/pipeline/state.py create mode 100644 aggregatetunnelmetrics/pipeline/window.py diff --git a/aggregatetunnelmetrics/ooniformatter/__init__.py b/aggregatetunnelmetrics/ooniformatter/__init__.py index fd51bdb..902e567 100644 --- a/aggregatetunnelmetrics/ooniformatter/__init__.py +++ b/aggregatetunnelmetrics/ooniformatter/__init__.py @@ -8,6 +8,7 @@ See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-ag from .formatter import ( AggregationTimeWindow, + Config, TestKeys, SerializationConfigError, Serializer, @@ -15,6 +16,7 @@ from .formatter import ( __all__ = [ "AggregationTimeWindow", + "Config", "SerializationConfigError", "Serializer", "TestKeys", diff --git a/aggregatetunnelmetrics/ooniformatter/formatter.py b/aggregatetunnelmetrics/ooniformatter/formatter.py index 146035e..05fb180 100644 --- a/aggregatetunnelmetrics/ooniformatter/formatter.py +++ b/aggregatetunnelmetrics/ooniformatter/formatter.py @@ -14,17 +14,12 @@ from urllib.parse import urlunparse, urlencode import logging -from ..globalscope import ( - AggregateProtocolState, - AggregatorConfig, - AggregateState, - datetime_to_compact_utc, -) -from ..oonireport import Measurement +from .. import globalscope +from .. import oonireport @dataclass(frozen=True) -class FormatterConfig: +class Config: """ Configuration for the OONI measurement formatter. """ @@ -48,8 +43,8 @@ class AggregationTimeWindow: def as_dict(self) -> dict: """Convert to JSON-serializable dict""" return { - "from": datetime_to_compact_utc(self.from_time), - "to": datetime_to_compact_utc(self.to_time), + "from": globalscope.datetime_to_compact_utc(self.from_time), + "to": globalscope.datetime_to_compact_utc(self.to_time), } @@ -109,7 +104,7 @@ class SerializationConfigError(Exception): class Serializer: """Converts aggregate endpoint state into OONI measurements""" - def __init__(self, ac: AggregatorConfig, fc: FormatterConfig): + def __init__(self, ac: globalscope.AggregatorConfig, fc: Config): self.aggregator_cfg = ac self.formatter_cfg = fc @@ -139,7 +134,7 @@ class Serializer: "99p": round(q[98], 1), } - def _create_input_url(self, state: AggregateProtocolState) -> str: + def _create_input_url(self, state: globalscope.AggregateProtocolState) -> str: """Create the measurement input URL""" # The query is empty when using the global state query = {} @@ -174,7 +169,7 @@ class Serializer: return obj def _create_error_bodies( - self, state: AggregateProtocolState + self, state: globalscope.AggregateProtocolState ) -> list[dict[str, Any]]: """Create error bodies if there are any errors""" bodies = [] @@ -197,7 +192,10 @@ class Serializer: return bodies def _validate_ping_measurements( - self, state: AggregateProtocolState, metric_type: str, measurements: list[float] + self, + state: globalscope.AggregateProtocolState, + metric_type: str, + measurements: list[float], ) -> None: """Validate ping measurements""" if metric_type in ["min", "avg", "max"]: @@ -214,7 +212,7 @@ class Serializer: raise SerializationConfigError("ping loss out of range") def _create_ping_bodies( - self, state: AggregateProtocolState + self, state: globalscope.AggregateProtocolState ) -> list[dict[str, Any]]: """Create bodies for ping measurements""" bodies = [] @@ -266,7 +264,9 @@ class Serializer: f"NDT {phase} retransmission out of range" ) - def _create_ndt_bodies(self, state: AggregateProtocolState) -> list[dict[str, Any]]: + def _create_ndt_bodies( + self, state: globalscope.AggregateProtocolState + ) -> list[dict[str, Any]]: """Create bodies for NDT measurements""" bodies = [] items = ( @@ -307,7 +307,7 @@ class Serializer: return bodies def _create_global_bodies( - self, state: AggregateProtocolState + self, state: globalscope.AggregateProtocolState ) -> list[dict[str, Any]]: """Create the bodies section of test_keys""" bodies = [] @@ -316,7 +316,7 @@ class Serializer: bodies.extend(self._create_ndt_bodies(state)) return bodies - def _is_valid_state(self, state: AggregateProtocolState) -> bool: + def _is_valid_state(self, state: globalscope.AggregateProtocolState) -> bool: """ Validates the state before serialization. Returns False if state should be skipped, True if it's valid to serialize. @@ -363,7 +363,9 @@ class Serializer: return True - def serialize_global(self, state: AggregateState) -> list[Measurement]: + def serialize_global( + self, state: globalscope.AggregateState + ) -> list[oonireport.Measurement]: """ Convert global state to OONI measurement format. @@ -393,7 +395,7 @@ class Serializer: bodies=self._create_global_bodies(proto_state), ) - mx = Measurement( + mx = oonireport.Measurement( annotations={ "upstream_collector": self.formatter_cfg.upstream_collector }, diff --git a/aggregatetunnelmetrics/pipeline/__init__.py b/aggregatetunnelmetrics/pipeline/__init__.py new file mode 100644 index 0000000..c1e11b6 --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/__init__.py @@ -0,0 +1,26 @@ +""" +High-level pipeline for processing and submitting aggregate tunnel metrics. + +This package provides the main interface for processing field testing CSV files +and submitting the resulting metrics to OONI. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from .config import ProcessConfig, FileIOConfig +from .processor import MetricsProcessor +from .state import ProcessorState +from .window import WindowPolicy, DailyWindowPolicy, WeeklyWindowPolicy +from .errors import PipelineError, StateError + +__all__ = [ + "ProcessConfig", + "FileIOConfig", + "MetricsProcessor", + "ProcessorState", + "WindowPolicy", + "DailyWindowPolicy", + "WeeklyWindowPolicy", + "PipelineError", + "StateError", +] diff --git a/aggregatetunnelmetrics/pipeline/config.py b/aggregatetunnelmetrics/pipeline/config.py new file mode 100644 index 0000000..848bdc1 --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/config.py @@ -0,0 +1,70 @@ +""" +Configuration classes for the metrics processing pipeline. + +This module is internal. Please, import `pipeline` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from datetime import datetime + +from .. import lockedfile + + +@dataclass(frozen=True) +class ProcessConfig: + """ + Configuration for the full metrics processing pipeline. + + For safety, you must explicitly provide a collector_base_url. You should configure + it to point either to a testing server or to `https://api.ooni.io/`. + + Fields: + provider: Name of the metrics provider. + upstream_collector: Name of the collector used to collect the CSV files. + probe_asn: ASN of the collector (becomes probe_asn in the OONI measurement). + probe_cc: Country code of the collector (becomes probe_cc in the OONI measurement). + min_sample_size: Minimum number of samples to include statistical information. + collector_base_url: Base URL of the OONI collector to use (mandatory). + timeout: Timeout for HTTP requests. + """ + + # Core identification + provider: str + + # Configuration for filling the measurement + upstream_collector: str + probe_asn: str + probe_cc: str + + # Mandatory collector configuration + collector_base_url: str + + # Optional measurement-filling configuration + min_sample_size: int = 1000 + + # Optional collector configuration + timeout: float = 30.0 + + +@dataclass(frozen=True) +class FileIOConfig: + """ + Configuration for file I/O operations. + + Fields: + state_file: Path to the file where to store state information. + num_retries: Number of retries to perform when acquiring the file lock. + sleep_interval: Time to wait between retries when acquiring the lock. + """ + + state_file: str + num_retries: int = 10 + sleep_interval: float = 0.1 + + def as_lockedfile_fileio_config(self) -> lockedfile.FileIOConfig: + """Convert to a lockedfile.FileIOConfig.""" + return lockedfile.FileIOConfig( + num_retries=self.num_retries, sleep_interval=self.sleep_interval + ) diff --git a/aggregatetunnelmetrics/pipeline/errors.py b/aggregatetunnelmetrics/pipeline/errors.py new file mode 100644 index 0000000..986a44c --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/errors.py @@ -0,0 +1,15 @@ +""" +Exceptions raised by the pipeline package. + +This module is internal. Please, import `pipeline` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +class PipelineError(Exception): + """Base class for pipeline-related errors.""" + + +class StateError(PipelineError): + """Raised when there are state management errors.""" diff --git a/aggregatetunnelmetrics/pipeline/processor.py b/aggregatetunnelmetrics/pipeline/processor.py new file mode 100644 index 0000000..192a81b --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/processor.py @@ -0,0 +1,159 @@ +""" +Main implementation of the metrics processor. + +This module is internal. Please, import `pipeline` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone +from typing import List + +import os +import tempfile + +from .config import ProcessConfig, FileIOConfig +from .state import ProcessorState +from .window import WindowPolicy + +from .. import fieldtestingcsv +from .. import globalscope +from .. import lockedfile +from .. import ooniformatter +from .. import oonireport + + +class MetricsProcessor: + """High-level API for processing and submitting tunnel metrics.""" + + def __init__( + self, + process_config: ProcessConfig, + fileio_config: FileIOConfig, + window_policy: WindowPolicy, + ): + self.process_config = process_config + self.fileio_config = fileio_config + self.window_policy = window_policy + + # Initialize configs for sub-components + self.aggregator_config = globalscope.AggregatorConfig( + provider=process_config.provider + ) + + self.collector_client = oonireport.CollectorClient( + oonireport.CollectorConfig( + collector_base_url=process_config.collector_base_url, + timeout=process_config.timeout, + ) + ) + + # Load initial state + self.state = ProcessorState.load( + fileio_config.state_file, + fileio_config.as_lockedfile_fileio_config(), + ) + + # Track current report ID + self._current_report_id: str | None = None + + def _submit_measurement(self, measurement: oonireport.Measurement) -> None: + """Submit a single measurement to OONI.""" + if not self._current_report_id: + self._current_report_id = ( + self.collector_client.create_report_from_measurement(measurement) + ) + + self.collector_client.update_report(self._current_report_id, measurement) + + def process_csv_file(self, csv_path: str) -> None: + """Process CSV file and submit measurements for complete windows. + + Args: + csv_path: Path to field testing CSV file + + Raises: + FileLockError: If cannot acquire locks + StateError: If state file operations fail + ValueError: If CSV parsing fails + SerializationConfigError: If measurement creation fails + """ + # Use mutex to ensure exclusive access + with lockedfile.Mutex(f"{csv_path}.lock"): + # Get a consistent snapshot of CSV + # TODO(bassosimone): consider whether streaming would be possible here + csv_content = lockedfile.read( + csv_path, + self.fileio_config.as_lockedfile_fileio_config(), + ) + + # Process in temp file + with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp: + tmp.write(csv_content) + tmp_path = tmp.name + + try: + # Get time windows to process + project_start = datetime(2024, 1, 1, tzinfo=timezone.utc) + start_time = self.state.next_submission_after or project_start + current_time = datetime.now(timezone.utc) + + windows = self.window_policy.generate_windows(start_time, current_time) + + # Process each window + for window_start, window_end in windows: + # TODO(bassosimone): this design has the issue that we parse the CSV file + # multiple times. Should we instead just parse it once or instead see whether + # we could split the file on creation into well-defined buckets? + self._process_window(tmp_path, window_start, window_end) + + # Update state after successful window processing + # TODO(bassosimone): double check whether using window_end here is proper. It should + # be, considering that we exclude windows containing the current time. + self.state.next_submission_after = window_end + self.state.save( + self.fileio_config.state_file, + self.fileio_config.as_lockedfile_fileio_config(), + ) + + finally: + os.unlink(tmp_path) + # Reset report ID for next processing + self._current_report_id = None + + def _process_window( + self, csv_path: str, window_start: datetime, window_end: datetime + ) -> None: + """Process entries within a specific time window.""" + # Parse and filter entries for window + entries = fieldtestingcsv.parse_file(csv_path) + window_entries = [e for e in entries if window_start <= e.date < window_end] + + if not window_entries: + return + + # Create aggregate state + state = globalscope.AggregateState( + config=self.aggregator_config, + window_start=window_start, + window_end=window_end, + ) + + # Update state with entries + for entry in window_entries: + state.update(entry) + + # Create and submit measurements + serializer = ooniformatter.Serializer( + self.aggregator_config, + ooniformatter.Config( + upstream_collector=self.process_config.upstream_collector, + probe_asn=self.process_config.probe_asn, + probe_cc=self.process_config.probe_cc, + min_sample_size=self.process_config.min_sample_size, + ), + ) + + measurements = serializer.serialize_global(state) + for measurement in measurements: + self._submit_measurement(measurement) diff --git a/aggregatetunnelmetrics/pipeline/state.py b/aggregatetunnelmetrics/pipeline/state.py new file mode 100644 index 0000000..82fee5a --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/state.py @@ -0,0 +1,59 @@ +""" +State management for the metrics processing pipeline. + +This module is internal. Please, import `pipeline` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone +import json + +from .errors import StateError + +from .. import lockedfile + + +@dataclass +class ProcessorState: + """Persistent state of the metrics processor.""" + + next_submission_after: datetime | None = None + + @classmethod + def load(cls, path: str, config: lockedfile.FileIOConfig) -> ProcessorState: + """Load state from file with proper locking.""" + try: + content = lockedfile.read(path, config) + data = json.loads(content) + + # Parse next_submission_after if present + if next_after := data.get("next_submission_after"): + next_after_dt = datetime.strptime(next_after, "%Y%m%dT%H%M%SZ").replace( + tzinfo=timezone.utc + ) + return cls(next_submission_after=next_after_dt) + + return cls() + + except FileNotFoundError: + return cls() + except json.JSONDecodeError as e: + raise StateError(f"Corrupt state file: {e}") + except ValueError as e: + raise StateError(f"Invalid datetime in state: {e}") + + def save(self, path: str, config: lockedfile.FileIOConfig) -> None: + """Save state to file with proper locking.""" + data = { + "next_submission_after": ( + self.next_submission_after.strftime("%Y%m%dT%H%M%SZ") + if self.next_submission_after + else None + ) + } + lockedfile.write(path, json.dumps(data), config) diff --git a/aggregatetunnelmetrics/pipeline/window.py b/aggregatetunnelmetrics/pipeline/window.py new file mode 100644 index 0000000..764d98f --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/window.py @@ -0,0 +1,67 @@ +""" +Time window management for the metrics processing pipeline. + +This module is internal. Please, import `pipeline` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from abc import ABC, abstractmethod +from datetime import datetime, timedelta + + +class WindowPolicy(ABC): + """Abstract base class for time window policies.""" + + def generate_windows( + self, last_submission: datetime, current_time: datetime + ) -> list[tuple[datetime, datetime]]: + """Generate all complete windows between last_submission and current_time. + + Args: + last_submission: Start point for window generation + current_time: Current time, used to determine incomplete windows + + Returns: + List of (start, end) tuples representing complete windows + """ + windows = [] + window_start = last_submission + + while window_start < current_time: + window_end = self._compute_window_end(window_start) + + # Stop if window would include current time + if current_time <= window_end: + break + + windows.append((window_start, window_end)) + window_start = window_end + + return windows + + @abstractmethod + def _compute_window_end(self, start: datetime) -> datetime: + """Compute the end time for a window starting at start.""" + pass + + +class DailyWindowPolicy(WindowPolicy): + """Organizes measurements in daily windows.""" + + def _compute_window_end(self, start: datetime) -> datetime: + """Return start of next day.""" + return start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( + days=1 + ) + + +class WeeklyWindowPolicy(WindowPolicy): + """Organizes measurements in weekly windows.""" + + def _compute_window_end(self, start: datetime) -> datetime: + """Return start of next week (Monday).""" + days_until_monday = (7 - start.isoweekday()) % 7 + return start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( + days=days_until_monday + ) diff --git a/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py b/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py index eb1cbe5..1a5d730 100644 --- a/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py +++ b/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py @@ -13,7 +13,7 @@ from aggregatetunnelmetrics.globalscope import ( AggregateProtocolState, ) from aggregatetunnelmetrics.ooniformatter.formatter import ( - FormatterConfig, + Config, Serializer, AggregationTimeWindow, TestKeys, @@ -43,7 +43,7 @@ class TestSerializer(unittest.TestCase): def setUp(self): """Set up common test fixtures.""" self.aggregator_config = AggregatorConfig(provider="test-provider") - self.formatter_config = FormatterConfig( + self.formatter_config = Config( upstream_collector="test-collector", probe_asn="AS12345", probe_cc="XX", -- GitLab From 7e722e085ddf7958a657cdaa4f45a328ae7cec98 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 22 Feb 2025 16:59:05 +0100 Subject: [PATCH 49/75] fix: refactor tests to work with multiprocess spawn I was previously using Linux where the multiprocess creation mechanism is fork. I am now on macOS, where it is spawn. To make tests work with spawn, I need to refactor them such that the entities executed on background processes are actual toplevel functions rather than closures. --- .../lockedfile/test_fileio.py | 45 ++++++--- .../lockedfile/test_mutex.py | 92 +++++++++++-------- 2 files changed, 86 insertions(+), 51 deletions(-) diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py index 3974bff..b3d42da 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py +++ b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Set from unittest.mock import patch, mock_open import fcntl @@ -15,6 +14,21 @@ import unittest from aggregatetunnelmetrics.lockedfile import common, fileio +# Writer process function needs to be at module level for macOS compatibility +def _writer_process(temp_path, valid_contents, should_stop, config): + while not should_stop.is_set(): + for content in valid_contents: + fileio.write(temp_path, content, config=config) + + +# Reader process function needs to be at module level for macOS compatibility +def _reader_process(temp_path, should_stop, results_queue, config): + while not should_stop.is_set(): + content = fileio.read(temp_path, config=config) + if content: + results_queue.put(content) + + class TestFileIOUnit(unittest.TestCase): """Unit tests for the lockedfile/fileio.py functionality.""" @@ -114,20 +128,21 @@ class TestFileIOIntegration(unittest.TestCase): should_stop = mp.Event() config = fileio.FileIOConfig(num_retries=30, sleep_interval=0.1) - def writer_process(): - while not should_stop.is_set(): - for content in valid_contents: - fileio.write(self.temp_path, content, config=config) - - writers = [mp.Process(target=writer_process) for _ in range(4)] - - def reader_process(): - while not should_stop.is_set(): - content = fileio.read(self.temp_path, config=config) - if content: - results_queue.put(content) - - readers = [mp.Process(target=reader_process) for _ in range(8)] + writers = [ + mp.Process( + target=_writer_process, + args=(self.temp_path, valid_contents, should_stop, config), + ) + for _ in range(4) + ] + + readers = [ + mp.Process( + target=_reader_process, + args=(self.temp_path, should_stop, results_queue, config), + ) + for _ in range(8) + ] # Start all processes for p in readers + writers: diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py index d39c362..0c89e5d 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py +++ b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py @@ -13,6 +13,42 @@ import unittest from aggregatetunnelmetrics.lockedfile import Mutex, FileLockError +def writer_process(lock_path, temp_path, valid_contents, should_stop, results_queue): + try: + while not should_stop.is_set(): + for content in valid_contents: + while not should_stop.is_set(): + try: + with Mutex(lock_path): + with open(temp_path, "w") as f: + f.write(content) + f.flush() + os.fsync(f.fileno()) + break # Success! Move to next content + except FileLockError: + # Lock busy, try again + time.sleep(0.1) + except Exception as exc: + results_queue.put(exc) + + +def reader_process(lock_path, temp_path, should_stop, results_queue): + try: + while not should_stop.is_set(): + try: + with Mutex(lock_path): + with open(temp_path) as f: + content = f.read() + if content: + results_queue.put(content) + break # Success! Continue outer loop + except FileLockError: + # Lock busy, try again + time.sleep(0.1) + except Exception as exc: + results_queue.put(exc) + + class TestMutexUnit(unittest.TestCase): """Unit tests for the lockedfile/mutex.py functionality.""" @@ -111,47 +147,31 @@ class TestMutexIntegration(unittest.TestCase): results_queue = mp.Queue() should_stop = mp.Event() - def writer_process(): - try: - while not should_stop.is_set(): - for content in valid_contents: - while not should_stop.is_set(): - try: - with Mutex(self.lock_path): - with open(self.temp_path, "w") as f: - f.write(content) - f.flush() - os.fsync(f.fileno()) - break # Success! Move to next content - except FileLockError: - # Lock busy, try again - time.sleep(0.1) - except Exception as exc: - results_queue.put(exc) - - def reader_process(): - try: - while not should_stop.is_set(): - try: - with Mutex(self.lock_path): - with open(self.temp_path) as f: - content = f.read() - if content: - results_queue.put(content) - break # Success! Continue outer loop - except FileLockError: - # Lock busy, try again - time.sleep(0.1) - except Exception as exc: - results_queue.put(exc) - # Create and initialize the file with open(self.temp_path, "w") as f: f.write(valid_contents[0]) # Start processes - writers = [mp.Process(target=writer_process) for _ in range(4)] - readers = [mp.Process(target=reader_process) for _ in range(8)] + writers = [ + mp.Process( + target=writer_process, + args=( + self.lock_path, + self.temp_path, + valid_contents, + should_stop, + results_queue, + ), + ) + for _ in range(4) + ] + readers = [ + mp.Process( + target=reader_process, + args=(self.lock_path, self.temp_path, should_stop, results_queue), + ) + for _ in range(8) + ] for p in readers + writers: p.start() -- GitLab From e900afb42d8ee4830b93dd888cb9a4055431285e Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sat, 22 Feb 2025 19:52:13 +0100 Subject: [PATCH 50/75] fix(pipeline): rewrite window aggregation I used AI to generate ~placholder code but it was a bit hard to test and generally not clear, so rewrite it and ensure we have some unit tests for the code as well. --- aggregatetunnelmetrics/pipeline/__init__.py | 12 +- aggregatetunnelmetrics/pipeline/processor.py | 34 +- aggregatetunnelmetrics/pipeline/window.py | 67 ---- .../pipeline/windowpolicy.py | 187 ++++++++++ .../pipeline/__init__.py | 0 .../pipeline/test_windowpolicy.py | 334 ++++++++++++++++++ 6 files changed, 539 insertions(+), 95 deletions(-) delete mode 100644 aggregatetunnelmetrics/pipeline/window.py create mode 100644 aggregatetunnelmetrics/pipeline/windowpolicy.py create mode 100644 tests/aggregatetunnelmetrics/pipeline/__init__.py create mode 100644 tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py diff --git a/aggregatetunnelmetrics/pipeline/__init__.py b/aggregatetunnelmetrics/pipeline/__init__.py index c1e11b6..db9fad4 100644 --- a/aggregatetunnelmetrics/pipeline/__init__.py +++ b/aggregatetunnelmetrics/pipeline/__init__.py @@ -8,19 +8,17 @@ and submitting the resulting metrics to OONI. # SPDX-License-Identifier: GPL-3.0-or-later from .config import ProcessConfig, FileIOConfig +from .errors import PipelineError, StateError from .processor import MetricsProcessor from .state import ProcessorState -from .window import WindowPolicy, DailyWindowPolicy, WeeklyWindowPolicy -from .errors import PipelineError, StateError +from .windowpolicy import Policy __all__ = [ - "ProcessConfig", "FileIOConfig", "MetricsProcessor", - "ProcessorState", - "WindowPolicy", - "DailyWindowPolicy", - "WeeklyWindowPolicy", "PipelineError", + "Policy", + "ProcessorState", + "ProcessConfig", "StateError", ] diff --git a/aggregatetunnelmetrics/pipeline/processor.py b/aggregatetunnelmetrics/pipeline/processor.py index 192a81b..c166562 100644 --- a/aggregatetunnelmetrics/pipeline/processor.py +++ b/aggregatetunnelmetrics/pipeline/processor.py @@ -6,15 +6,12 @@ This module is internal. Please, import `pipeline` directly instead. # SPDX-License-Identifier: GPL-3.0-or-later -from datetime import datetime, timezone -from typing import List - import os import tempfile from .config import ProcessConfig, FileIOConfig from .state import ProcessorState -from .window import WindowPolicy +from .windowpolicy import Policy, Window, generate_windows from .. import fieldtestingcsv from .. import globalscope @@ -30,7 +27,7 @@ class MetricsProcessor: self, process_config: ProcessConfig, fileio_config: FileIOConfig, - window_policy: WindowPolicy, + window_policy: Policy, ): self.process_config = process_config self.fileio_config = fileio_config @@ -94,23 +91,20 @@ class MetricsProcessor: try: # Get time windows to process - project_start = datetime(2024, 1, 1, tzinfo=timezone.utc) - start_time = self.state.next_submission_after or project_start - current_time = datetime.now(timezone.utc) - - windows = self.window_policy.generate_windows(start_time, current_time) + windows = generate_windows( + policy=self.window_policy, + reference=self.state.next_submission_after, + ) # Process each window - for window_start, window_end in windows: + for window in windows: # TODO(bassosimone): this design has the issue that we parse the CSV file # multiple times. Should we instead just parse it once or instead see whether # we could split the file on creation into well-defined buckets? - self._process_window(tmp_path, window_start, window_end) + self._process_window(tmp_path, window) # Update state after successful window processing - # TODO(bassosimone): double check whether using window_end here is proper. It should - # be, considering that we exclude windows containing the current time. - self.state.next_submission_after = window_end + self.state.next_submission_after = window.end self.state.save( self.fileio_config.state_file, self.fileio_config.as_lockedfile_fileio_config(), @@ -121,13 +115,11 @@ class MetricsProcessor: # Reset report ID for next processing self._current_report_id = None - def _process_window( - self, csv_path: str, window_start: datetime, window_end: datetime - ) -> None: + def _process_window(self, csv_path: str, window: Window) -> None: """Process entries within a specific time window.""" # Parse and filter entries for window entries = fieldtestingcsv.parse_file(csv_path) - window_entries = [e for e in entries if window_start <= e.date < window_end] + window_entries = [e for e in entries if window.includes_datetime(e.date)] if not window_entries: return @@ -135,8 +127,8 @@ class MetricsProcessor: # Create aggregate state state = globalscope.AggregateState( config=self.aggregator_config, - window_start=window_start, - window_end=window_end, + window_start=window.start, + window_end=window.end, ) # Update state with entries diff --git a/aggregatetunnelmetrics/pipeline/window.py b/aggregatetunnelmetrics/pipeline/window.py deleted file mode 100644 index 764d98f..0000000 --- a/aggregatetunnelmetrics/pipeline/window.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -Time window management for the metrics processing pipeline. - -This module is internal. Please, import `pipeline` directly instead. -""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -from abc import ABC, abstractmethod -from datetime import datetime, timedelta - - -class WindowPolicy(ABC): - """Abstract base class for time window policies.""" - - def generate_windows( - self, last_submission: datetime, current_time: datetime - ) -> list[tuple[datetime, datetime]]: - """Generate all complete windows between last_submission and current_time. - - Args: - last_submission: Start point for window generation - current_time: Current time, used to determine incomplete windows - - Returns: - List of (start, end) tuples representing complete windows - """ - windows = [] - window_start = last_submission - - while window_start < current_time: - window_end = self._compute_window_end(window_start) - - # Stop if window would include current time - if current_time <= window_end: - break - - windows.append((window_start, window_end)) - window_start = window_end - - return windows - - @abstractmethod - def _compute_window_end(self, start: datetime) -> datetime: - """Compute the end time for a window starting at start.""" - pass - - -class DailyWindowPolicy(WindowPolicy): - """Organizes measurements in daily windows.""" - - def _compute_window_end(self, start: datetime) -> datetime: - """Return start of next day.""" - return start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( - days=1 - ) - - -class WeeklyWindowPolicy(WindowPolicy): - """Organizes measurements in weekly windows.""" - - def _compute_window_end(self, start: datetime) -> datetime: - """Return start of next week (Monday).""" - days_until_monday = (7 - start.isoweekday()) % 7 - return start.replace(hour=0, minute=0, second=0, microsecond=0) + timedelta( - days=days_until_monday - ) diff --git a/aggregatetunnelmetrics/pipeline/windowpolicy.py b/aggregatetunnelmetrics/pipeline/windowpolicy.py new file mode 100644 index 0000000..7f4897c --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/windowpolicy.py @@ -0,0 +1,187 @@ +""" +Time window management for the metrics processing pipeline. + +This module is internal. Please, import `pipeline` directly instead. + +Note: + All datetime objects provided to this module must be in UTC timezone. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True) +class Window: + """A time window with start and end times.""" + + start: datetime + end: datetime + delta: timedelta + + def __str__(self) -> str: + """Return a string representation of the window.""" + return f"Window({self.start.isoformat()} -> {self.end.isoformat()})" + + def __post_init__(self): + _validate_utc(self.start, "window start") + _validate_utc(self.end, "window end") + if self.start >= self.end: + raise ValueError("window start must be before end") + if self.delta <= timedelta(0): + raise ValueError("window delta must be positive") + if self.end - self.start != self.delta: + raise ValueError("window delta must match end - start") + + def before_datetime(self, t: datetime) -> bool: + """Check if the window ends before the given time.""" + return self.start <= t + + def includes_datetime(self, t: datetime) -> bool: + """Check if the window includes the given time.""" + return self.start <= t < self.end + + def next_window(self) -> Window: + """Return the next window.""" + return Window(self.end, self.end + self.delta, self.delta) + + +@runtime_checkable +class Policy(Protocol): + """Models a policy for generating time windows. + + Methods: + start_window: Generates the first window given a reference time. + """ + + def start_window(self, reference: datetime) -> Window: ... + + +class DailyPolicy: + """A policy that generates daily windows. + + Windows start at midnight UTC and run for 24 hours. + Given a reference time, returns a window starting from midnight + of the day containing the reference time. + """ + + def start_window(self, reference: datetime) -> Window: + """Implements Policy.start. + + Args: + reference: The reference time. + + Returns: + A Window starting from 00:00 UTC of the current day, + ending the following day. + """ + _validate_utc(reference, "reference") + today_at_midnight = reference.replace(hour=0, minute=0, second=0, microsecond=0) + delta = timedelta(days=1) + tomorrow_at_midnight = today_at_midnight + delta + return Window(today_at_midnight, tomorrow_at_midnight, delta) + + +class WeeklyPolicy: + """A policy that generates weekly windows. + + Windows start on Monday at midnight UTC and run for 7 days. + Given a reference time, returns a window starting from the Monday + of the week containing the reference time. + """ + + def start_window(self, reference: datetime) -> Window: + """Implements Policy.start. + + Args: + reference: The reference time. + + Returns: + A Window starting from Monday 00:00 UTC of the current week, + ending the following Monday. + """ + _validate_utc(reference, "reference") + + # Get midnight of the reference day + today_at_midnight = reference.replace(hour=0, minute=0, second=0, microsecond=0) + + # Calculate days since last Monday (isoweekday: Monday=1, Sunday=7) + days_since_monday = today_at_midnight.isoweekday() - 1 + + # Get this week's Monday at midnight + this_monday_at_midnight = today_at_midnight - timedelta(days=days_since_monday) + + delta = timedelta(days=7) + next_monday_at_midnight = this_monday_at_midnight + delta + + return Window(this_monday_at_midnight, next_monday_at_midnight, delta) + + +project_start_time = datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc) +"""The start time of the project, used as the default reference time for window generation.""" + + +def datetime_utcnow() -> datetime: + """Convenience function to return the current UTC time.""" + return datetime.now(timezone.utc) + + +def _validate_utc(dt: datetime, param_name: str) -> None: + """Validate that a datetime is UTC. + + Args: + dt: The datetime to validate. + param_name: Name of the parameter for the error message. + + Raises: + ValueError: If the datetime is not in UTC timezone. + """ + if dt.tzinfo != timezone.utc: + raise ValueError(f"{param_name} must be in UTC timezone") + + +def generate_windows( + policy: Policy, + reference: datetime | None = None, + now: datetime | None = None, +) -> list[Window]: + """Generates all the windows between a reference time and now using a policy. + + Args: + policy: The policy implementing the window generation strategy. + reference: Optional reference time for starting window generation + (defaults to project_start_time). + now: Optional current time (defaults to current UTC time). + + Returns: + WindowList containing generated windows and the next start time. + + Raises: + ValueError: If the policy generates invalid windows (through Window validation). + ValueError: If the provided times are not in UTC. + """ + + # Ensure reference is a valid datetime + if reference is None: + reference = project_start_time + _validate_utc(reference, "reference") + + # Ensure now is a valid datetime + if now is None: + now = datetime_utcnow() + _validate_utc(now, "now") + + # Initialize by creating the initial window + windows: list[Window] = [] + window = policy.start_window(reference) + + # Generate windows until the current window contains now + while window.before_datetime(now) and not window.includes_datetime(now): + windows.append(window) + window = window.next_window() + + # Return the generated windows + return windows diff --git a/tests/aggregatetunnelmetrics/pipeline/__init__.py b/tests/aggregatetunnelmetrics/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py b/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py new file mode 100644 index 0000000..f5ed3aa --- /dev/null +++ b/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py @@ -0,0 +1,334 @@ +"""Unit tests for the windowpolicy module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +import unittest +from datetime import datetime, timedelta, timezone + +from aggregatetunnelmetrics.pipeline.windowpolicy import ( + DailyPolicy, + WeeklyPolicy, + Window, + _validate_utc, + datetime_utcnow, + generate_windows, +) + + +class TestWindow(unittest.TestCase): + """Test the Window class.""" + + def test_valid_window_creation(self): + """Test creating a valid window.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=1) + window = Window(start, end, delta) + self.assertEqual(window.start, start) + self.assertEqual(window.end, end) + self.assertEqual(window.delta, delta) + + def test_window_validation_non_utc(self): + """Test window creation with non-UTC times.""" + start = datetime(2024, 1, 1) # naive datetime + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=1) + with self.assertRaises(ValueError): + Window(start, end, delta) + + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2) # naive datetime + with self.assertRaises(ValueError): + Window(start, end, delta) + + def test_window_validation_invalid_times(self): + """Test window creation with invalid time combinations.""" + start = datetime(2024, 1, 2, tzinfo=timezone.utc) + end = datetime(2024, 1, 1, tzinfo=timezone.utc) + delta = timedelta(days=1) + with self.assertRaises(ValueError): + Window(start, end, delta) + + def test_window_validation_invalid_delta(self): + """Test window creation with invalid delta.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=-1) + with self.assertRaises(ValueError): + Window(start, end, delta) + + delta = timedelta(0) + with self.assertRaises(ValueError): + Window(start, end, delta) + + def test_window_validation_mismatched_delta(self): + """Test window creation with delta not matching start-end difference.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=2) + with self.assertRaises(ValueError): + Window(start, end, delta) + + def test_window_before_datetime(self): + """Test the before_datetime method.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = Window(start, end, timedelta(days=1)) + + # Test various cases + self.assertTrue( + window.before_datetime(datetime(2024, 1, 1, tzinfo=timezone.utc)) + ) + self.assertTrue( + window.before_datetime(datetime(2024, 1, 2, tzinfo=timezone.utc)) + ) + self.assertTrue( + window.before_datetime(datetime(2024, 1, 3, tzinfo=timezone.utc)) + ) + self.assertFalse( + window.before_datetime(datetime(2023, 12, 31, tzinfo=timezone.utc)) + ) + + def test_window_includes_datetime(self): + """Test the includes_datetime method.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = Window(start, end, timedelta(days=1)) + + # Test various cases + self.assertTrue( + window.includes_datetime(datetime(2024, 1, 1, tzinfo=timezone.utc)) + ) + self.assertTrue( + window.includes_datetime(datetime(2024, 1, 1, 12, tzinfo=timezone.utc)) + ) + self.assertFalse( + window.includes_datetime(datetime(2024, 1, 2, tzinfo=timezone.utc)) + ) + self.assertFalse( + window.includes_datetime(datetime(2023, 12, 31, tzinfo=timezone.utc)) + ) + + def test_window_next_window(self): + """Test the next_window method.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = Window(start, end, timedelta(days=1)) + + next_window = window.next_window() + self.assertEqual(next_window.start, window.end) + self.assertEqual(next_window.end, datetime(2024, 1, 3, tzinfo=timezone.utc)) + self.assertEqual(next_window.delta, window.delta) + + def test_window_string_representation(self): + """Test the string representation of Window.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = Window(start, end, timedelta(days=1)) + expected = f"Window({start.isoformat()} -> {end.isoformat()})" + self.assertEqual(str(window), expected) + + +class TestDailyPolicy(unittest.TestCase): + """Test the DailyPolicy class.""" + + def test_daily_policy_start_window_different_times(self): + """Test DailyPolicy.start_window with reference at different times of day.""" + policy = DailyPolicy() + + # Test different times during January 1, 2024 + test_cases = [ + # (reference_time, expected_start) + ( + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), # midnight + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 1, 12, 30, tzinfo=timezone.utc), # noon + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 1, 23, 59, tzinfo=timezone.utc), # end of day + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ] + + for reference, expected_start in test_cases: + with self.subTest(reference=reference): + window = policy.start_window(reference) + self.assertEqual(window.start, expected_start) + self.assertEqual(window.end, expected_start + timedelta(days=1)) + self.assertEqual(window.delta, timedelta(days=1)) + + def test_daily_policy_across_days(self): + """Test DailyPolicy.start_window across different days.""" + policy = DailyPolicy() + + test_cases = [ + # (reference_time, expected_start) + ( + datetime(2024, 1, 1, 12, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 2, 12, tzinfo=timezone.utc), + datetime(2024, 1, 2, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 3, 12, tzinfo=timezone.utc), + datetime(2024, 1, 3, 0, 0, tzinfo=timezone.utc), + ), + ] + + for reference, expected_start in test_cases: + with self.subTest(reference=reference): + window = policy.start_window(reference) + self.assertEqual(window.start, expected_start) + self.assertEqual(window.end, expected_start + timedelta(days=1)) + self.assertEqual(window.delta, timedelta(days=1)) + + def test_daily_policy_non_utc(self): + """Test DailyPolicy with non-UTC reference.""" + policy = DailyPolicy() + reference = datetime(2024, 1, 1, 15, 30) # naive datetime + with self.assertRaises(ValueError): + policy.start_window(reference) + + +class TestWeeklyPolicy(unittest.TestCase): + """Test the WeeklyPolicy class.""" + + def test_weekly_policy_start_window(self): + """Test WeeklyPolicy.start_window for different days of the week.""" + policy = WeeklyPolicy() + + # Test for each day of the week + test_cases = [ + # Monday + ( + datetime(2024, 1, 1, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + # Tuesday + ( + datetime(2024, 1, 2, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + # Wednesday + ( + datetime(2024, 1, 3, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + # Thursday + ( + datetime(2024, 1, 4, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + # Friday + ( + datetime(2024, 1, 5, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + # Saturday + ( + datetime(2024, 1, 6, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + # Sunday + ( + datetime(2024, 1, 7, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ] + + for reference, expected_start in test_cases: + with self.subTest(reference=reference): + window = policy.start_window(reference) + self.assertEqual(window.start, expected_start) + self.assertEqual(window.end, expected_start + timedelta(days=7)) + self.assertEqual(window.delta, timedelta(days=7)) + + def test_weekly_policy_non_utc(self): + """Test WeeklyPolicy with non-UTC reference.""" + policy = WeeklyPolicy() + reference = datetime(2024, 1, 1, 15, 30) # naive datetime + with self.assertRaises(ValueError): + policy.start_window(reference) + + +class TestGenerateWindows(unittest.TestCase): + """Test the generate_windows function.""" + + def test_generate_windows_default_params(self): + """Test generate_windows with default parameters.""" + policy = DailyPolicy() + windows = generate_windows(policy) + self.assertIsInstance(windows, list) + self.assertGreater(len(windows), 0) + + def test_generate_windows_custom_reference(self): + """Test generate_windows with custom reference time.""" + policy = DailyPolicy() + reference = datetime(2024, 1, 1, tzinfo=timezone.utc) + now = datetime(2024, 1, 3, tzinfo=timezone.utc) + windows = generate_windows(policy, reference=reference, now=now) + + self.assertEqual(len(windows), 2) # Should have Jan 1-2 and Jan 2-3 + self.assertEqual(windows[0].start, reference) + + def test_generate_windows_non_utc(self): + """Test generate_windows with non-UTC times.""" + policy = DailyPolicy() + reference = datetime(2024, 1, 1) # naive datetime + now = datetime(2024, 1, 2, tzinfo=timezone.utc) + + with self.assertRaises(ValueError): + generate_windows(policy, reference=reference, now=now) + + reference = datetime(2024, 1, 1, tzinfo=timezone.utc) + now = datetime(2024, 1, 2) # naive datetime + + with self.assertRaises(ValueError): + generate_windows(policy, reference=reference, now=now) + + def test_generate_windows_reference_after_now(self): + """Test generate_windows with reference time after now.""" + policy = DailyPolicy() + reference = datetime(2024, 1, 2, tzinfo=timezone.utc) + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + + windows = generate_windows(policy, reference=reference, now=now) + self.assertEqual(len(windows), 0) + + def test_generate_windows_same_day(self): + """Test generate_windows with reference and now on the same day.""" + policy = DailyPolicy() + reference = datetime(2024, 1, 1, tzinfo=timezone.utc) + now = datetime(2024, 1, 1, 12, tzinfo=timezone.utc) + + windows = generate_windows(policy, reference=reference, now=now) + self.assertEqual(len(windows), 0) + + +class TestUtilities(unittest.TestCase): + """Test utility functions.""" + + def test_datetime_utcnow(self): + """Test datetime_utcnow function.""" + now = datetime_utcnow() + self.assertEqual(now.tzinfo, timezone.utc) + + def test_validate_utc(self): + """Test _validate_utc function.""" + # Valid UTC datetime + dt = datetime(2024, 1, 1, tzinfo=timezone.utc) + _validate_utc(dt, "test") # Should not raise + + # Naive datetime + dt = datetime(2024, 1, 1) + with self.assertRaises(ValueError): + _validate_utc(dt, "test") + + +if __name__ == "__main__": + unittest.main() -- GitLab From afc60507e39bb753724e4108c878cd9b49b3471e Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 10:24:23 +0100 Subject: [PATCH 51/75] fix: remove unneeded import --- aggregatetunnelmetrics/pipeline/config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aggregatetunnelmetrics/pipeline/config.py b/aggregatetunnelmetrics/pipeline/config.py index 848bdc1..009aba2 100644 --- a/aggregatetunnelmetrics/pipeline/config.py +++ b/aggregatetunnelmetrics/pipeline/config.py @@ -7,7 +7,6 @@ This module is internal. Please, import `pipeline` directly instead. # SPDX-License-Identifier: GPL-3.0-or-later from dataclasses import dataclass -from datetime import datetime from .. import lockedfile -- GitLab From 19adc411e80f25fc3763e6360e5527809790043a Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 10:37:02 +0100 Subject: [PATCH 52/75] feat(pipeline): write tests for config.py and state.py --- .../pipeline/test_config.py | 91 ++++++++++++ .../pipeline/test_state.py | 131 ++++++++++++++++++ 2 files changed, 222 insertions(+) create mode 100644 tests/aggregatetunnelmetrics/pipeline/test_config.py create mode 100644 tests/aggregatetunnelmetrics/pipeline/test_state.py diff --git a/tests/aggregatetunnelmetrics/pipeline/test_config.py b/tests/aggregatetunnelmetrics/pipeline/test_config.py new file mode 100644 index 0000000..2caf3e8 --- /dev/null +++ b/tests/aggregatetunnelmetrics/pipeline/test_config.py @@ -0,0 +1,91 @@ +"""Tests for the pipeline configuration module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +import unittest + +from aggregatetunnelmetrics.pipeline.config import ProcessConfig, FileIOConfig +from aggregatetunnelmetrics.lockedfile import FileIOConfig as LockedFileIOConfig + + +class TestProcessConfig(unittest.TestCase): + """Test ProcessConfig functionality.""" + + def test_valid_minimal_config(self): + """Test creating config with just mandatory fields.""" + config = ProcessConfig( + provider="test-provider", + upstream_collector="test-collector", + probe_asn="AS12345", + probe_cc="XX", + collector_base_url="https://api.ooni.io/", + ) + + self.assertEqual(config.provider, "test-provider") + self.assertEqual(config.upstream_collector, "test-collector") + self.assertEqual(config.probe_asn, "AS12345") + self.assertEqual(config.probe_cc, "XX") + self.assertEqual(config.collector_base_url, "https://api.ooni.io/") + self.assertEqual(config.min_sample_size, 1000) # default value + self.assertEqual(config.timeout, 30.0) # default value + + def test_valid_full_config(self): + """Test creating config with all fields specified.""" + config = ProcessConfig( + provider="test-provider", + upstream_collector="test-collector", + probe_asn="AS12345", + probe_cc="XX", + collector_base_url="https://api.ooni.io/", + min_sample_size=500, + timeout=60.0, + ) + + self.assertEqual(config.provider, "test-provider") + self.assertEqual(config.upstream_collector, "test-collector") + self.assertEqual(config.probe_asn, "AS12345") + self.assertEqual(config.probe_cc, "XX") + self.assertEqual(config.collector_base_url, "https://api.ooni.io/") + self.assertEqual(config.min_sample_size, 500) + self.assertEqual(config.timeout, 60.0) + + +class TestFileIOConfig(unittest.TestCase): + """Test FileIOConfig functionality.""" + + def test_valid_minimal_config(self): + """Test creating config with just mandatory fields.""" + config = FileIOConfig(state_file="/path/to/state.json") + + self.assertEqual(config.state_file, "/path/to/state.json") + self.assertEqual(config.num_retries, 10) # default value + self.assertEqual(config.sleep_interval, 0.1) # default value + + def test_valid_full_config(self): + """Test creating config with all fields specified.""" + config = FileIOConfig( + state_file="/path/to/state.json", + num_retries=5, + sleep_interval=0.2, + ) + + self.assertEqual(config.state_file, "/path/to/state.json") + self.assertEqual(config.num_retries, 5) + self.assertEqual(config.sleep_interval, 0.2) + + def test_conversion_to_lockedfile_config(self): + """Test conversion to lockedfile.FileIOConfig.""" + config = FileIOConfig( + state_file="/path/to/state.json", + num_retries=5, + sleep_interval=0.2, + ) + + locked_config = config.as_lockedfile_fileio_config() + self.assertIsInstance(locked_config, LockedFileIOConfig) + self.assertEqual(locked_config.num_retries, 5) + self.assertEqual(locked_config.sleep_interval, 0.2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/aggregatetunnelmetrics/pipeline/test_state.py b/tests/aggregatetunnelmetrics/pipeline/test_state.py new file mode 100644 index 0000000..5541eb8 --- /dev/null +++ b/tests/aggregatetunnelmetrics/pipeline/test_state.py @@ -0,0 +1,131 @@ +"""Tests for the pipeline state management.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone +import json +import os +import tempfile +import unittest +from unittest.mock import patch + +from aggregatetunnelmetrics.pipeline.state import ProcessorState +from aggregatetunnelmetrics.pipeline.errors import StateError +from aggregatetunnelmetrics.lockedfile import FileIOConfig + + +class TestProcessorState(unittest.TestCase): + """Test ProcessorState functionality.""" + + def setUp(self): + """Create a temporary file for testing.""" + self.temp_dir = tempfile.mkdtemp() + self.state_file = os.path.join(self.temp_dir, "state.json") + self.config = FileIOConfig(num_retries=1, sleep_interval=0.1) + + def tearDown(self): + """Clean up temporary files.""" + try: + os.unlink(self.state_file) + os.rmdir(self.temp_dir) + except FileNotFoundError: + pass + + def test_initial_state(self): + """Test initial state creation.""" + state = ProcessorState() + self.assertIsNone(state.next_submission_after) + + def test_load_nonexistent_file(self): + """Test loading from a nonexistent file returns default state.""" + state = ProcessorState.load(self.state_file, self.config) + self.assertIsNone(state.next_submission_after) + + def test_load_valid_state(self): + """Test loading valid state from file.""" + # Write valid state file + test_time = "20240101T120000Z" + with open(self.state_file, "w") as f: + json.dump({"next_submission_after": test_time}, f) + + state = ProcessorState.load(self.state_file, self.config) + expected_dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + self.assertEqual(state.next_submission_after, expected_dt) + + def test_load_corrupt_json(self): + """Test loading corrupted JSON state file.""" + # Write invalid JSON + with open(self.state_file, "w") as f: + f.write("not valid json{") + + with self.assertRaises(StateError) as cm: + ProcessorState.load(self.state_file, self.config) + self.assertIn("Corrupt state file", str(cm.exception)) + + def test_load_invalid_datetime(self): + """Test loading state with invalid datetime format.""" + # Write state with invalid datetime + with open(self.state_file, "w") as f: + json.dump({"next_submission_after": "invalid-date"}, f) + + with self.assertRaises(StateError) as cm: + ProcessorState.load(self.state_file, self.config) + self.assertIn("Invalid datetime in state", str(cm.exception)) + + def test_load_missing_next_submission_after(self): + """Test loading state file with missing next_submission_after field returns default state.""" + # Write state file without next_submission_after field + with open(self.state_file, "w") as f: + json.dump({}, f) + + state = ProcessorState.load(self.state_file, self.config) + self.assertIsNone(state.next_submission_after) + + def test_save_state(self): + """Test saving state to file.""" + dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + state = ProcessorState(next_submission_after=dt) + state.save(self.state_file, self.config) + + # Verify saved content + with open(self.state_file) as f: + saved_state = json.load(f) + self.assertEqual(saved_state["next_submission_after"], "20240101T120000Z") + + def test_save_none_state(self): + """Test saving state with None datetime.""" + state = ProcessorState(next_submission_after=None) + state.save(self.state_file, self.config) + + # Verify saved content + with open(self.state_file) as f: + saved_state = json.load(f) + self.assertIsNone(saved_state["next_submission_after"]) + + def test_save_with_file_error(self): + """Test saving state with file write error.""" + state = ProcessorState() + + # Mock file write to fail + with patch("aggregatetunnelmetrics.lockedfile.write") as mock_write: + mock_write.side_effect = IOError("Mock write error") + + with self.assertRaises(IOError): + state.save(self.state_file, self.config) + + def test_load_save_roundtrip(self): + """Test that saved state can be loaded correctly.""" + # Create and save initial state + dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + original_state = ProcessorState(next_submission_after=dt) + original_state.save(self.state_file, self.config) + + # Load state back + loaded_state = ProcessorState.load(self.state_file, self.config) + + # Verify loaded state matches original + self.assertEqual(loaded_state.next_submission_after, dt) + + +if __name__ == "__main__": + unittest.main() -- GitLab From 006f98f59e4ecfe37fe83c12ca0b3a69299df0b6 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 15:41:49 +0100 Subject: [PATCH 53/75] feat: start adding common interface and using pytest --- .../aggregators/__init__.py | 0 aggregatetunnelmetrics/aggregators/common.py | 274 ++++++++++++++ .../globalscope/aggregate.py | 1 - aggregatetunnelmetrics/spec/__init__.py | 45 +++ aggregatetunnelmetrics/spec/aggregator.py | 32 ++ aggregatetunnelmetrics/spec/fieldtesting.py | 80 ++++ aggregatetunnelmetrics/spec/filelocking.py | 63 ++++ aggregatetunnelmetrics/spec/metrics.py | 336 +++++++++++++++++ aggregatetunnelmetrics/spec/oonicollector.py | 222 +++++++++++ pyproject.toml | 14 + .../lockedfile/test_fileio.py | 29 +- .../lockedfile/test_mutex.py | 15 +- .../pipeline/test_processor.py | 354 ++++++++++++++++++ tests/aggregatetunnelmetrics/spec/__init__.py | 3 + .../spec/test_aggregator.py | 23 ++ .../spec/test_fieldtesting.py | 42 +++ .../spec/test_filelocking.py | 42 +++ .../spec/test_metrics.py | 175 +++++++++ .../spec/test_oonicollector.py | 153 ++++++++ uv.lock | 190 ++++++++++ 20 files changed, 2075 insertions(+), 18 deletions(-) create mode 100644 aggregatetunnelmetrics/aggregators/__init__.py create mode 100644 aggregatetunnelmetrics/aggregators/common.py create mode 100644 aggregatetunnelmetrics/spec/__init__.py create mode 100644 aggregatetunnelmetrics/spec/aggregator.py create mode 100644 aggregatetunnelmetrics/spec/fieldtesting.py create mode 100644 aggregatetunnelmetrics/spec/filelocking.py create mode 100644 aggregatetunnelmetrics/spec/metrics.py create mode 100644 aggregatetunnelmetrics/spec/oonicollector.py create mode 100644 pyproject.toml create mode 100644 tests/aggregatetunnelmetrics/pipeline/test_processor.py create mode 100644 tests/aggregatetunnelmetrics/spec/__init__.py create mode 100644 tests/aggregatetunnelmetrics/spec/test_aggregator.py create mode 100644 tests/aggregatetunnelmetrics/spec/test_fieldtesting.py create mode 100644 tests/aggregatetunnelmetrics/spec/test_filelocking.py create mode 100644 tests/aggregatetunnelmetrics/spec/test_metrics.py create mode 100644 tests/aggregatetunnelmetrics/spec/test_oonicollector.py create mode 100644 uv.lock diff --git a/aggregatetunnelmetrics/aggregators/__init__.py b/aggregatetunnelmetrics/aggregators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py new file mode 100644 index 0000000..5becd17 --- /dev/null +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -0,0 +1,274 @@ +""" +Common Aggregation Code +======================= + +TODO... +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +from dataclasses import dataclass, field +from statistics import quantiles +from typing import Iterator, Protocol, runtime_checkable + +from . import fieldtesting +from . import metrics +from . import oonicollector + + +def make_distribution(values: list[float]) -> metrics.Distribution | None: + """Generates an empirical distribution from a list of values.""" + # TODO(bassosimone): this code is most likely wrong + if not values: + return None + q = quantiles(values, n=100, method="exclusive") + return metrics.Distribution( + p25=q[24], + p50=q[49], + p75=q[74], + p99=q[98], + ) + + +@dataclass +class CreationMetrics: + """ + Allows tracking tunnel creation metrics. + + Fields: + errors: A dictionary with error type as key and error count as value. + num_samples: The total number of samples. + + Methods: + update: Updates the metrics with a new entry. + """ + + errors: dict[str, int] = field(default_factory=dict) + num_samples: int = 0 + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + if entry.is_tunnel_measurement(): + if entry.is_tunnel_error_measurement(): + error_type = "bootstrap.generic_error" + self.errors[error_type] = self.errors.get(error_type, 0) + 1 + self.num_samples += 1 + + def statements(self) -> list[metrics.NetworkErrorStatement]: + """Generates statements for the creation metrics.""" + result = [] + for error, count in self.errors.items(): + result.append( + metrics.NetworkErrorStatement( + sample_size=self.num_samples, + failure_ratio=count / self.num_samples, + error=error, + ) + ) + return result + + +@dataclass +class PingMetricsPerTarget: + """ + Allows tracking tunnel ping metrics for a specific target. + + Fields: + target_address: The optional ping target address. + min: The minimum roundtrip time. + avg: The average roundtrip time. + max: The maximum roundtrip time. + loss: The packet loss. + + Methods: + update: Updates the metrics with a new entry. + """ + + target_address: str | None + min: list[float] = field(default_factory=list) + avg: list[float] = field(default_factory=list) + max: list[float] = field(default_factory=list) + loss: list[float] = field(default_factory=list) + num_samples: int = 0 + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + if entry.is_tunnel_measurement() and not entry.is_tunnel_error_measurement(): + self.min.append(entry.ping_roundtrip_min) + self.avg.append(entry.ping_roundtrip_avg) + self.max.append(entry.ping_roundtrip_max) + self.loss.append(entry.ping_packets_loss) + self.num_samples += 1 + + def statements(self) -> list[metrics.TunnelPingStatement]: + """Generates statements for the ping metrics.""" + return [ + metrics.TunnelPingStatement( + target_address=self.target_address, + sample_size=self.num_samples, + latency_avg=make_distribution(self.avg), + ) + ] + + +@dataclass +class PingMetricsOverall: + """ + Allows tracking tunnel ping metrics for all targets. + + Fields: + targets: A dictionary mapping targets to their metrics. + + Methods: + update: Updates the metrics with a new entry. + """ + + targets: dict[str, PingMetricsPerTarget] = field(default_factory=dict) + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + if entry.is_tunnel_measurement() and not entry.is_tunnel_error_measurement(): + target = "f{entry.ping_target_address}" + if target not in self.targets: + self.targets[target] = PingMetricsPerTarget(target_address=target) + self.targets[target].update(entry) + + def statements(self) -> list[metrics.TunnelPingStatement]: + """Generates statements for ping metrics.""" + results = [] + for target in self.targets.values(): + results.extend(target.statements()) + return results + + +@dataclass +class NDTMetricsPerTarget: + """ + Allows tracking tunnel NDT metrics for a specific target. + + Fields: + target_hostname: The optional NDT target hostname. + target_address: The optional NDT target address. + target_port: The optional NDT target port. + download_throughput: The download throughput. + download_latency: The download latency. + download_rexmit: The download retransmission. + upload_throughput: The upload throughput. + upload_latency: The upload latency. + upload_rexmit: The upload retransmission. + + Methods: + update: Updates the metrics with a new entry. + """ + + target_hostname: str | None + target_address: str | None + target_port: int | None + download_throughput: list[float] = field(default_factory=list) + download_latency: list[float] = field(default_factory=list) + download_rexmit: list[float] = field(default_factory=list) + upload_throughput: list[float] = field(default_factory=list) + upload_latency: list[float] = field(default_factory=list) + upload_rexmit: list[float] = field(default_factory=list) + num_samples: int = 0 + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + if entry.is_tunnel_measurement() and not entry.is_tunnel_error_measurement(): + self.download_throughput.append(entry.throughput_download) + self.download_latency.append(entry.latency_download) + self.download_rexmit.append(entry.retransmission_download) + self.upload_throughput.append(entry.throughput_upload) + self.upload_latency.append(entry.latency_upload) + self.upload_rexmit.append(entry.retransmission_upload) + self.num_samples += 1 + + def statements(self) -> list[metrics.TunnelNDTStatement]: + """Generates statements for NDT metrics.""" + return [ + metrics.TunnelNDTStatement( + direction="download", + target_hostname=self.target_hostname, + target_address=self.target_address, + target_port=self.target_port, + sample_size=self.num_samples, + latency=make_distribution(self.download_latency), + speed=make_distribution(self.download_throughput), + ), + metrics.TunnelNDTStatement( + direction="upload", + target_hostname=self.target_hostname, + target_address=self.target_address, + target_port=self.target_port, + sample_size=self.num_samples, + latency=make_distribution(self.upload_latency), + speed=make_distribution(self.upload_throughput), + ), + ] + + +@dataclass +class NDTMetricsOverall: + """ + Allows tracking tunnel NDT metrics for all targets. + + Fields: + known_targets: A dictionary mapping targets to their metrics. + + Methods: + update: Updates the metrics with a new entry. + """ + + targets: dict[str, NDTMetricsPerTarget] = field(default_factory=dict) + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + if entry.is_tunnel_measurement() and not entry.is_tunnel_error_measurement(): + key = f"{entry.ndt_target_hostname} {entry.ndt_target_address} {entry.ndt_target_port}" + if key not in self.targets: + self.targets[key] = NDTMetricsPerTarget( + target_hostname=entry.ndt_target_hostname, + target_address=entry.ndt_target_address, + target_port=entry.ndt_target_port, + ) + self.targets[key].update(entry) + + def statements(self) -> list[metrics.TunnelNDTStatement]: + """Generates statements for NDT metrics.""" + results = [] + for target in self.targets.values(): + results.extend(target.statements()) + return results + + +@dataclass +class AggregationUnitMetrics: + """ + Allows tracking an aggregation unit metrics. + + The aggregation unit depends on the aggregation policy. + + Fields: + creation: The creation metrics. + tunnel_ping: The tunnel ping metrics. + tunnel_ndt: The tunnel NDT metrics. + """ + + creation: CreationMetrics = CreationMetrics() + tunnel_ping: PingMetricsOverall = PingMetricsOverall() + tunnel_ndt: NDTMetricsOverall = NDTMetricsOverall() + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + self.creation.update(entry) + self.tunnel_ping.update(entry) + self.tunnel_ndt.update(entry) + + def statements(self) -> list[metrics.Statement]: + """Generates a statement for the aggregation unit metrics.""" + result = [] + result.extend(self.creation.statements()) + result.extend(self.tunnel_ping.statements()) + result.extend(self.tunnel_ndt.statements()) + return result diff --git a/aggregatetunnelmetrics/globalscope/aggregate.py b/aggregatetunnelmetrics/globalscope/aggregate.py index 4666e14..a8fd24d 100644 --- a/aggregatetunnelmetrics/globalscope/aggregate.py +++ b/aggregatetunnelmetrics/globalscope/aggregate.py @@ -10,7 +10,6 @@ from __future__ import annotations from dataclasses import dataclass, field from datetime import datetime, timezone -from enum import Enum from .. import fieldtestingcsv diff --git a/aggregatetunnelmetrics/spec/__init__.py b/aggregatetunnelmetrics/spec/__init__.py new file mode 100644 index 0000000..3590487 --- /dev/null +++ b/aggregatetunnelmetrics/spec/__init__.py @@ -0,0 +1,45 @@ +""" +Pipeline Specification +====================== + +This package specifies the field-testing data aggregation +pipeline, defining core data structures and behaviour. + +The code in this package is close to literate programming; reading +it should provide an overall pipeline understanding. + +High-Level Overview +------------------- + +The pipeline: + +- reads field testing data organized as CSV files; +- manages file locking to ensure consistency; +- aggregates the CSV files content into aggregate metrics; +- transforms the aggregates into OONI measurements; +- submits measurements to the OONI collector; +- remembers when it stopped processing to avoid reprocessing. + +Data Flow +--------- + +CSV Files --> [File Locking] --> [Aggregation] --> [OONI Format] --> OONI Collector + ^ | + | v + +--- State File <-- [Track Progress] + +Aggregation applies time windows and scope policies (global/pool/endpoint). + +Modules +------- +- fieldtesting: Models raw CSV data format and streaming. +- filelocking: File locking API for safe concurrent access. +- metrics: Core data structures for aggregated metrics. +- aggregator: Logic for aggregating raw data into metrics. +- oonicollector: OONI submission protocol and client. + +Each module is designed to handle one specific aspect of the pipeline, with clear +interfaces between components to maintain separation of concerns. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/aggregatetunnelmetrics/spec/aggregator.py b/aggregatetunnelmetrics/spec/aggregator.py new file mode 100644 index 0000000..c852ae3 --- /dev/null +++ b/aggregatetunnelmetrics/spec/aggregator.py @@ -0,0 +1,32 @@ +""" +Aggregator Model +================ + +This module contains the generic aggregator model. + +Classes: + Logic: Models the generic aggregator logic. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +from typing import Iterator, Protocol, runtime_checkable + +from . import fieldtesting +from . import oonicollector + + +@runtime_checkable +class Logic(Protocol): + """ + Models the generic aggregator logic. + + Methods: + aggregate: Aggregates Entry into OONI Measurement. + """ + + def aggregate( + self, + entries: Iterator[fieldtesting.Entry], + ) -> Iterator[oonicollector.Measurement]: ... diff --git a/aggregatetunnelmetrics/spec/fieldtesting.py b/aggregatetunnelmetrics/spec/fieldtesting.py new file mode 100644 index 0000000..08dfa91 --- /dev/null +++ b/aggregatetunnelmetrics/spec/fieldtesting.py @@ -0,0 +1,80 @@ +""" +Field-Testing Model +=================== + +This module contains the field-testing CSV data model. + +Classes: + Entry: Models a single field-testing CSV entry. + Streamer: Allows streaming Entry objects from a given file. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass +from datetime import datetime +from typing import Iterator, Protocol, runtime_checkable + + +@dataclass(frozen=True) +class Entry: + """ + Models a single field-testing CSV entry. + + The order of the fields in this dataclass it the same + of the fields within the CSV file. + """ + + # Fields present in the CSV files as of 2024-12-06 + filename: str + date: datetime + asn: str + isp: str + est_city: str + user: str + region: str + server_fqdn: str + server_ip: str + mobile: bool + tunnel: str # 'baseline', 'tunnel', 'ERROR/baseline', 'ERROR/tunnel' + throughput_download: float + throughput_upload: float + latency_download: float + latency_upload: float + retransmission_download: float + retransmission_upload: float + ping_packets_loss: float + ping_roundtrip_min: float + ping_roundtrip_avg: float + ping_roundtrip_max: float + err_message: str + protocol: str + + # Additional fields not present in the CSV files + ping_target_address: str | None + ndt_target_hostname: str | None + ndt_target_address: str | None + ndt_target_port: int | None + + def is_tunnel_measurement(self) -> bool: + """ + Return whether this is a tunnel measurement, which includes both + successful and failed tunnel measurements. + """ + return self.tunnel in ("tunnel", "ERROR/tunnel") + + def is_tunnel_error_measurement(self) -> bool: + """Return whether this is a failed tunnel measurement""" + return self.tunnel == "ERROR/tunnel" + + +@runtime_checkable +class Streamer(Protocol): + """ + Allows streaming Entry objects from a given file. + + Methods: + stream: stream Entry objects from the given file. + """ + + def stream(self, filepath: str) -> Iterator[Entry]: ... diff --git a/aggregatetunnelmetrics/spec/filelocking.py b/aggregatetunnelmetrics/spec/filelocking.py new file mode 100644 index 0000000..e7ee73c --- /dev/null +++ b/aggregatetunnelmetrics/spec/filelocking.py @@ -0,0 +1,63 @@ +""" +File Locking API +================ + +This module defines the file-locking API we use. + +Classes: + ReadWriteConfig: configuration for locked file I/O operations. + Mutex: context manager using a lockfile for mutex. + API: abstract file-locking API used by the pipeline. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + + +@dataclass(frozen=True) +class ReadWriteConfig: + """ + Configuration for Protocol.readfile and Protocol.writefile. + + Fields: + num_retries: number of retries before bailing. + sleep_interval: seconds to sleep between each retry. + """ + + num_retries: int = 10 + sleep_interval: float = 0.1 + + +@runtime_checkable +class Mutex(Protocol): + """ + Context manager providing mutual exclusion via file locking. + + Regardless of exceptions, __exit__ unlocks the mutex. + """ + + def __enter__(self) -> Mutex: ... + + def __exit__(self, *args) -> None: ... + + +@runtime_checkable +class API(Protocol): + """ + Abstract file-locking API used by the pipeline. + + Methods: + readfile: lock file and read its contents. + writefile: lock file and write it. + mutex: create a mutex using a lockfile. + """ + + def readfile(self, config: ReadWriteConfig | None) -> str: ... + + def writefile(self, data: str, config: ReadWriteConfig | None) -> None: ... + + def mutex(self, filepath: str) -> Mutex: ... diff --git a/aggregatetunnelmetrics/spec/metrics.py b/aggregatetunnelmetrics/spec/metrics.py new file mode 100644 index 0000000..a36344c --- /dev/null +++ b/aggregatetunnelmetrics/spec/metrics.py @@ -0,0 +1,336 @@ +""" +Metrics Definitions +=================== + +Defines the metrics submitted by the aggregate_tunnel_metrics experiment. + +See https://0xacab.org/leap/dev-documentation/-/blob/no-masters/proposals/005-aggregate-tunnel-metrics.md. + +Classes: + MeasurementTestKeys: models the measurement test keys. + Scope: models the aggregation scope. + GlobalScope: models aggregation at the global scope. + EndpointPoolScope: models aggregation at the endpoint_pool scope. + EndpointScope: models aggregation at the endpoint scope. + TimeWindow: models the time window for aggregation. + Statement: models a statement included inside the MeasurementTestKeys. + NetworkErrorStatement: models a statement about network errors. + Distribution: models a distribution of values. + TunnelPingStatement: models a tunnel_ping statement. + TunnelNDTStatement: models a tunnel_ndt_{download,upload} statement. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class Scope(Protocol): + """ + Models a generic aggregation scope. + + Methods: + update_dict: Updates the raw dictionary with the scope keys. + """ + + def update_dict(self, dct: dict) -> None: ... + + +@dataclass(frozen=True) +class GlobalScope: + """ + Models aggregation at the global scope. + + Attributes: + protocol: the protocol being measured. + + Methods: + update_dict: Implements the Scope protocol. + """ + + protocol: str + + def update_dict(self, dct: dict) -> None: + """Implements the Scope protocol.""" + dct["scope"] = "global" + dct["protocol"] = self.protocol + + +@dataclass(frozen=True) +class EndpointPoolScope: + """ + Models aggregation at the endpoint pool scope. + + Attributes: + protocol: the protocol being measured. + cc: the country code of the measured endpoints. + + Methods: + update_dict: Implements the Scope protocol. + """ + + protocol: str + cc: str + + def update_dict(self, dct: dict) -> None: + """Implements the Scope protocol.""" + dct["scope"] = "endpoint_pool" + dct["protocol"] = self.protocol + dct["cc"] = self.cc + + +@dataclass(frozen=True) +class EndpointScope: + """ + Models aggregation at the endpoint scope. + + Attributes: + protocol: the protocol being measured. + cc: the country code of the measured endpoints. + asn: the ASN of the measured endpoints. + endpoint_hostname: the hostname of the measured endpoints. + endpoint_address: the IP address of the measured endpoints. + endpoint_port: the port of the measured endpoints. + + Methods: + update_dict: Implements the Scope protocol. + """ + + protocol: str + cc: str + asn: str + endpoint_hostname: str + endpoint_address: str + endpoint_port: int + + def update_dict(self, dct: dict) -> None: + """Implements the Scope protocol.""" + dct["scope"] = "endpoint" + dct["protocol"] = self.protocol + dct["cc"] = self.cc + dct["asn"] = self.asn + dct["endpoint_hostname"] = self.endpoint_hostname + dct["endpoint_address"] = self.endpoint_address + dct["endpoint_port"] = self.endpoint_port + + +@dataclass(frozen=True) +class TimeWindow: + """ + Represents a time window for aggregation. + + Fields: + start: start of the time window. + end: end of the time window. + + Methods: + as_dict: Converts the time window to a JSON-serializable dict. + """ + + start: datetime + end: datetime + + def as_dict(self) -> dict: + """Converts the time window to a JSON-serializable dict.""" + return { + "start": format_datetime(self.start), + "end": format_datetime(self.end), + } + + +@runtime_checkable +class Statement(Protocol): + """ + A statement included inside the TestKeys. + + Methods: + as_dict: Converts the statement to a JSON-serializable dict. + """ + + def as_dict(self) -> dict: ... + + +@dataclass(frozen=True) +class NetworkErrorStatement: + """ + Statement about the network errors that occurred. + + Fields: + sample_size: the number of samples. + failure_ratio: the ratio of failures. + error: the error that occurred. + + Methods: + as_dict: Implements the Statement protocol. + """ + + sample_size: int | None + failure_ratio: float + error: str + + def as_dict(self) -> dict: + """Implements the Statement protocol.""" + return { + "phase": "creation", + "sample_size": self.sample_size, + "type": "network-error", + "failure_ratio": self.failure_ratio, + "error": self.error, + } + + +@dataclass(frozen=True) +class Distribution: + """ + Represents a distribution of values. + + Fields: + p25: the 25th percentile. + p50: the 50th percentile. + p75: the 75th percentile. + p99: the 99th percentile. + + Methods: + as_dict: Converts the distribution to a JSON-serializable dict. + """ + + p25: float + p50: float + p75: float + p99: float + + def as_dict(self) -> dict: + """Converts the distribution to a JSON-serializable dict.""" + return { + "25p": self.p25, + "50p": self.p50, + "75p": self.p75, + "99p": self.p99, + } + + +@dataclass(frozen=True) +class TunnelPingStatement: + """ + Models a tunnel_ping statement. + + Fields: + target_address: the IP address being pinged. + sample_size: the number of samples. + latency_min: the minimum latency distribution. + latency_avg: the average latency distribution. + latency_max: the maximum latency distribution. + loss: the packet loss distribution. + + Methods: + as_dict: Implements the Statement protocol. + """ + + target_address: str | None + sample_size: int + latency_min: Distribution | None + latency_avg: Distribution | None + latency_max: Distribution | None + loss: Distribution | None + + def as_dict(self) -> dict: + """Implements the Statement protocol.""" + return { + "phase": "tunnel_ping", + "target_address": self.target_address, + "sample_size": self.sample_size, + "type": "ping", + "latency_min_ms": self.latency_min.as_dict() if self.latency_min else None, + "latency_avg_ms": self.latency_avg.as_dict() if self.latency_avg else None, + "latency_max_ms": self.latency_max.as_dict() if self.latency_max else None, + "loss": self.loss.as_dict() if self.loss else None, + } + + +@dataclass(frozen=True) +class TunnelNDTStatement: + """ + Models a tunnel_ndt_{download,upload} statement. + + Fields: + direction: "download" or "upload". + target_hostname: the hostname being measured. + target_address: the IP address being measured. + target_port: the port being measured. + sample_size: the number of samples. + latency: the latency distribution. + speed: the download speed distribution. + rexmit: the retransmission distribution. + + Methods: + as_dict: Implements the Statement protocol. + """ + + direction: str # "download" | "upload" + target_hostname: str | None + target_address: str | None + target_port: int | None + sample_size: int | None + latency: Distribution | None + speed: Distribution | None + rexmit: Distribution | None + + def as_dict(self) -> dict: + """Implements the Statement protocol.""" + return { + "phase": f"tunnel_ndt_{self.direction}", + "target_hostname": self.target_hostname, + "target_address": self.target_address, + "target_port": self.target_port, + "sample_size": self.sample_size, + "type": "ndt_download", + "latency_ms": self.latency.as_dict() if self.latency else None, + "speed_mbits": self.speed.as_dict() if self.speed else None, + "rexmit": self.rexmit.as_dict() if self.rexmit else None, + } + + +@dataclass(frozen=True) +class MeasurementTestKeys: + """ + Models the aggregate_tunnel_metrics measurement test keys. + + Fields: + provider: the provider being measured. + scope: the scope of the measurement. + time_window: the time window of the measurement. + bodies: the statements made during the measurement. + + Methods: + as_dict: Converts the test keys to a JSON-serializable dict. + """ + + # for this provider + provider: str + + # with this scope + scope: Scope + + # in this time window + time_window: TimeWindow + + # we make the following statements + bodies: list[Statement] + + def as_dict(self) -> dict: + """Converts the test keys to a JSON-serializable dict.""" + dct = {} + dct["provider"] = self.provider + self.scope.update_dict(dct) + dct["time_window"] = self.time_window.as_dict() + dct["bodies"] = [body.as_dict() for body in self.bodies] + return dct + + +def format_datetime(dt: datetime) -> str: + """Convert datetime to compact UTC format (YYYYMMDDThhmmssZ)""" + return dt.astimezone(timezone.utc).strftime("%Y%m%dT%H%M%SZ") diff --git a/aggregatetunnelmetrics/spec/oonicollector.py b/aggregatetunnelmetrics/spec/oonicollector.py new file mode 100644 index 0000000..5784d71 --- /dev/null +++ b/aggregatetunnelmetrics/spec/oonicollector.py @@ -0,0 +1,222 @@ +""" +OONI Collector Model +==================== + +This module contains the OONI collector data model. + +Classes: + TestKeys: models the OONI measurement experiment-specific test keys. + Measurement: models the OONI measurement envelope. + Config: configures the OONI collector client. + OpenReportRequest: contains data required to open an OONI report. + ReportID: type alias describing an open OONI report ID. + MaybeMeasurementID: type alias describing an optional OONI measurement ID. + Client: allows submitting measurements to the OONI collector. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +from dataclasses import dataclass, field, replace +from datetime import datetime, timezone +from typing import Any, Protocol, runtime_checkable + + +@runtime_checkable +class TestKeys(Protocol): + """ + Models the OONI measurement test keys. + + Methods: + as_dict: Converts the test keys to a JSON-serializable dict. + """ + + def as_dict(self) -> dict[str, Any]: ... + + +@dataclass(frozen=True) +class Measurement: + """ + Models the OONI measurement envelope. + + Methods: + as_open_report_request: Converts the measurement to an OpenReportRequest. + as_dict: Converts the Measurement to a JSON-serializable dict. + with_report_id: Creates a new Measurement instance with the given report_id. + """ + + # mandatory fields + annotations: dict[str, str] + data_format_version: str + input: str # e.g., {protocol}://{provider}/?{query_string} + measurement_start_time: datetime + probe_asn: str # Format: ^AS[0-9]+$ + probe_cc: str # Format: ^[A-Z]{2}$ + software_name: str + software_version: str + test_keys: TestKeys + test_name: str + test_runtime: float + test_start_time: datetime + test_version: str + + # Fields emitted with possibly default values + probe_ip: str = "127.0.0.1" + report_id: str = "" + + # Optional fields + options: list[str] = field(default_factory=list) + probe_network_name: str = "" + resolver_asn: str = "" + resolver_cc: str = "" + resolver_ip: str = "" + resolver_network_name: str = "" + test_helpers: dict[str, Any] = field(default_factory=dict) + + def as_open_report_request(self) -> OpenReportRequest: + """Converts the measurement to an OpenReportRequest.""" + return OpenReportRequest( + probe_asn=self.probe_asn, + probe_cc=self.probe_cc, + software_name=self.software_name, + software_version=self.software_version, + test_name=self.test_name, + test_start_time=self.test_start_time, + test_version=self.test_version, + ) + + def as_dict(self) -> dict: + """Converts the measurement to a JSON-serializable dict""" + + # Add mandatory fields + dct = { + "annotations": self.annotations, + "data_format_version": self.data_format_version, + "input": self.input, + "measurement_start_time": format_datetime(self.measurement_start_time), + "probe_asn": self.probe_asn, + "probe_cc": self.probe_cc, + "software_name": self.software_name, + "software_version": self.software_version, + "test_keys": self.test_keys.as_dict(), + "test_name": self.test_name, + "test_runtime": self.test_runtime, + "test_start_time": format_datetime(self.test_start_time), + "test_version": self.test_version, + } + + # Fields emitted with possibly default values + dct["probe_ip"] = self.probe_ip if self.probe_ip else "127.0.0.1" + dct["report_id"] = self.report_id + + # Add optional fields + if self.options: + dct["options"] = self.options + if self.probe_network_name: + dct["probe_network_name"] = self.probe_network_name + if self.resolver_asn: + dct["resolver_asn"] = self.resolver_asn + if self.resolver_cc: + dct["resolver_cc"] = self.resolver_cc + if self.resolver_ip: + dct["resolver_ip"] = self.resolver_ip + if self.resolver_network_name: + dct["resolver_network_name"] = self.resolver_network_name + if self.test_helpers: + dct["test_helpers"] = self.test_helpers + + return dct + + def with_report_id(self, report_id: str) -> Measurement: + """Creates a new Measurement instance with the given report_id.""" + return replace(self, report_id=report_id) + + +def format_datetime(dt: datetime) -> str: + """Converts a datetime to OONI's datetime format (YYYY-mm-dd HH:MM:SS).""" + return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") + + +@dataclass(frozen=True) +class Config: + """ + Configures the OONI collector client. + + Fields: + collector_base_url: mandatory base URL to use (e.g., "https://api.ooni.io/"). + timeout: optional operations timeout in seconds (default: 30.0). + """ + + collector_base_url: str # e.g., "https://api.ooni.io/" + timeout: float = 30.0 + + +@dataclass(frozen=True) +class OpenReportRequest: + """ + Contains data required to open an OONI report. + + Fields: + report_id: the report ID. + test_name: the test name. + test_version: the test version. + software_name: the software name. + software_version: the software version. + probe_asn: the probe ASN. + probe_cc: the probe country code. + test_start_time: the test start time. + """ + + probe_asn: str + probe_cc: str + software_name: str + software_version: str + test_name: str + test_start_time: datetime + test_version: str + + def as_dict(self) -> dict: + return { + "data_format_version": "0.2.0", + "format": "json", + "probe_asn": self.probe_asn, + "probe_cc": self.probe_cc, + "software_name": self.software_name, + "software_version": self.software_version, + "test_name": self.test_name, + "test_start_time": format_datetime(self.test_start_time), + "test_version": self.test_version, + } + + +ReportID = str +"""Type alias describing an open OONI report ID.""" + + +MaybeMeasurementID = str | None +""" +Type alias describing an optional OONI measurement ID. + +The current API does not guaranteee that a measurement ID is returned\ +upon submitting a measurement to the OONI collector. +""" + + +@runtime_checkable +class Client(Protocol): + """ + Allows submitting measurements to the OONI collector. + + Methods: + create_report: Creates a new report and returns the report ID. + submit_measurement: Submit measurement and returns the measurement ID. + """ + + def create_report(self, req: OpenReportRequest) -> ReportID: ... + + def submit_measurement( + self, + rid: ReportID, + m: Measurement, + ) -> MaybeMeasurementID: ... diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..c017058 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "aggregatetunnelmetrics" +version = "0.1.0" +description = "Solitech monitoring library" +requires-python = ">=3.12" +dependencies = [] +license = "GPL-3.0-or-later" + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[dependency-groups] +dev = ["black>=24.10.0", "pytest", "pytest-cov"] diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py index b3d42da..b0004ec 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py +++ b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py @@ -47,9 +47,11 @@ class TestFileIOUnit(unittest.TestCase): """Test read fails when unable to acquire lock.""" m = mock_open() - with patch("builtins.open", m), patch( - "fcntl.flock", side_effect=BlockingIOError() - ), patch("time.sleep"): + with ( + patch("builtins.open", m), + patch("fcntl.flock", side_effect=BlockingIOError()), + patch("time.sleep"), + ): with self.assertRaises(common.FileLockError): fileio.read("test.txt") @@ -59,9 +61,11 @@ class TestFileIOUnit(unittest.TestCase): test_content = "test content" m = mock_open() - with patch("builtins.open", m), patch("fcntl.flock") as mock_flock, patch( - "os.fsync" - ) as mock_fsync: + with ( + patch("builtins.open", m), + patch("fcntl.flock") as mock_flock, + patch("os.fsync") as mock_fsync, + ): fileio.write("test.txt", test_content) @@ -74,9 +78,11 @@ class TestFileIOUnit(unittest.TestCase): """Test write fails when unable to acquire lock.""" m = mock_open() - with patch("builtins.open", m), patch( - "fcntl.flock", side_effect=BlockingIOError() - ), patch("time.sleep"): + with ( + patch("builtins.open", m), + patch("fcntl.flock", side_effect=BlockingIOError()), + patch("time.sleep"), + ): with self.assertRaises(common.FileLockError): fileio.write("test.txt", "content") @@ -91,8 +97,9 @@ class TestFileIOUnit(unittest.TestCase): raise OSError("mock release error") return None # Success for lock acquisition - with patch("builtins.open", m), patch( - "fcntl.flock", side_effect=flock_side_effect + with ( + patch("builtins.open", m), + patch("fcntl.flock", side_effect=flock_side_effect), ): # Should complete without raising exception result = fileio.read("test.txt") diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py index 0c89e5d..e6960ec 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py +++ b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py @@ -79,8 +79,9 @@ class TestMutexUnit(unittest.TestCase): """Test mutex fails to acquire with IOError.""" m = mock_open() - with patch("builtins.open", m), patch( - "fcntl.flock", side_effect=IOError("mock error") + with ( + patch("builtins.open", m), + patch("fcntl.flock", side_effect=IOError("mock error")), ): with self.assertRaises(FileLockError) as cm: @@ -93,8 +94,9 @@ class TestMutexUnit(unittest.TestCase): """Test mutex fails to acquire when already locked.""" m = mock_open() - with patch("builtins.open", m), patch( - "fcntl.flock", side_effect=BlockingIOError() + with ( + patch("builtins.open", m), + patch("fcntl.flock", side_effect=BlockingIOError()), ): with self.assertRaises(FileLockError) as cm: @@ -112,8 +114,9 @@ class TestMutexUnit(unittest.TestCase): raise OSError("mock release error") return None - with patch("builtins.open", m), patch( - "fcntl.flock", side_effect=flock_side_effect + with ( + patch("builtins.open", m), + patch("fcntl.flock", side_effect=flock_side_effect), ): # Should complete without raising exception with Mutex("test.lock"): diff --git a/tests/aggregatetunnelmetrics/pipeline/test_processor.py b/tests/aggregatetunnelmetrics/pipeline/test_processor.py new file mode 100644 index 0000000..2933ad8 --- /dev/null +++ b/tests/aggregatetunnelmetrics/pipeline/test_processor.py @@ -0,0 +1,354 @@ +"""Tests for the metrics processing pipeline.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timedelta, timezone +import os +import tempfile +import unittest +from unittest.mock import Mock, patch + +from aggregatetunnelmetrics.pipeline.processor import MetricsProcessor +from aggregatetunnelmetrics.pipeline.config import ProcessConfig, FileIOConfig +from aggregatetunnelmetrics.pipeline.windowpolicy import DailyPolicy, Window +from aggregatetunnelmetrics.fieldtestingcsv import Entry +from aggregatetunnelmetrics.oonireport import APIError + + +class TestMetricsProcessor(unittest.TestCase): + """Test MetricsProcessor functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Setup the filesystem + self.temp_dir = tempfile.mkdtemp() + self.csv_path = os.path.join(self.temp_dir, "test.csv") + self.state_path = os.path.join(self.temp_dir, "state.json") + + # Time fixture - using a fixed known time + self.reference_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + # Create basic configs + self.process_config = ProcessConfig( + provider="test-provider", + upstream_collector="test-collector", + probe_asn="AS12345", + probe_cc="XX", + collector_base_url="https://example.org/", + min_sample_size=10, # Small value for testing + ) + + self.fileio_config = FileIOConfig( + state_file=self.state_path, + num_retries=1, + sleep_interval=0.1, + ) + + self.window_policy = DailyPolicy() + + # Create processor instance + self.processor = MetricsProcessor( + self.process_config, + self.fileio_config, + self.window_policy, + ) + + # Create empty CSV file + with open(self.csv_path, "w") as f: + f.write("") + + # Setup collector mock + self.collector_patcher = patch( + "aggregatetunnelmetrics.oonireport.CollectorClient" + ) + self.mock_collector = self.collector_patcher.start() + self.mock_collector_instance = Mock() + self.mock_collector.return_value = self.mock_collector_instance + + # Setup file operations mock + self.read_patcher = patch("aggregatetunnelmetrics.lockedfile.read") + self.mock_read = self.read_patcher.start() + self.mock_read.return_value = "" + + def tearDown(self): + """Clean up temporary files.""" + self.collector_patcher.stop() + self.read_patcher.stop() + try: + os.unlink(self.csv_path) + os.unlink(self.state_path) + os.rmdir(self.temp_dir) + except FileNotFoundError: + pass + + def test_initialization(self): + """Test proper initialization of MetricsProcessor.""" + self.assertEqual(self.processor.process_config, self.process_config) + self.assertEqual(self.processor.fileio_config, self.fileio_config) + self.assertEqual(self.processor.window_policy, self.window_policy) + self.assertEqual( + self.processor.aggregator_config.provider, self.process_config.provider + ) + self.assertIsNone(self.processor._current_report_id) + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") + def test_empty_csv_processing(self, mock_generate_windows, mock_mutex, mock_parse): + """Test processing an empty CSV file.""" + # Set up empty window list + mock_generate_windows.return_value = [] + + # Mock parse_file to return empty list + mock_parse.return_value = [] + + # Mock mutex context manager + mock_mutex.return_value.__enter__.return_value = Mock() + mock_mutex.return_value.__exit__.return_value = None + + self.processor.process_csv_file(self.csv_path) + + # Verify mutex was used with correct path + mock_mutex.assert_called_once_with(f"{self.csv_path}.lock") + mock_parse.assert_not_called() + + def _create_test_entry(self, date: datetime, is_tunnel: bool = True) -> Entry: + """Helper to create a test entry.""" + return Entry( + filename="test.csv", + date=date, + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel" if is_tunnel else "baseline", + protocol="obfs4", + throughput_download=100.0, + throughput_upload=50.0, + latency_download=20.0, + latency_upload=25.0, + retransmission_download=0.01, + retransmission_upload=0.02, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + ) + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") + def test_successful_processing(self, mock_generate_windows, mock_mutex, mock_parse): + """Test successful processing of CSV with valid entries.""" + # Set up window + window = Window( + start=self.reference_time, + end=self.reference_time + timedelta(days=1), + delta=timedelta(days=1), + ) + mock_generate_windows.return_value = [window] + + # Set up entries + entries = [self._create_test_entry(self.reference_time + timedelta(hours=1))] + mock_parse.return_value = entries + + # Mock mutex context manager + mock_mutex.return_value.__enter__.return_value = Mock() + mock_mutex.return_value.__exit__.return_value = None + + # Set up successful submission + self.mock_collector_instance.create_report_from_measurement.return_value = ( + "test-report" + ) + + # Process the file + self.processor.process_csv_file(self.csv_path) + + # Verify submission happened + self.mock_collector_instance.create_report_from_measurement.assert_called_once() + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") + def test_skip_out_of_window_entries( + self, mock_generate_windows, mock_mutex, mock_parse + ): + """Test that entries outside the current window are skipped.""" + # Set up window + window = Window( + start=self.reference_time, + end=self.reference_time + timedelta(days=1), + delta=timedelta(days=1), + ) + mock_generate_windows.return_value = [window] + + # Create test entries outside the window + entries = [ + self._create_test_entry(window.start - timedelta(days=1)), + self._create_test_entry(window.end + timedelta(days=1)), + ] + mock_parse.return_value = entries + + # Mock mutex context manager + mock_mutex.return_value.__enter__.return_value = Mock() + mock_mutex.return_value.__exit__.return_value = None + + self.processor.process_csv_file(self.csv_path) + + # Verify no measurements were created + self.assertIsNone(self.processor._current_report_id) + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") + def test_skip_non_tunnel_entries( + self, mock_generate_windows, mock_mutex, mock_parse + ): + """Test that non-tunnel entries are skipped.""" + # Set up window + window = Window( + start=self.reference_time, + end=self.reference_time + timedelta(days=1), + delta=timedelta(days=1), + ) + mock_generate_windows.return_value = [window] + + # Create test entries with non-tunnel measurements + entries = [ + self._create_test_entry(window.start + timedelta(hours=1), is_tunnel=False), + self._create_test_entry(window.start + timedelta(hours=2), is_tunnel=False), + ] + mock_parse.return_value = entries + + # Mock mutex context manager + mock_mutex.return_value.__enter__.return_value = Mock() + mock_mutex.return_value.__exit__.return_value = None + + self.processor.process_csv_file(self.csv_path) + + # Verify no measurements were created + self.assertIsNone(self.processor._current_report_id) + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") + def test_submission_error_handling( + self, mock_generate_windows, mock_mutex, mock_parse + ): + """Test handling of measurement submission errors.""" + # Set up window + window = Window( + start=self.reference_time, + end=self.reference_time + timedelta(days=1), + delta=timedelta(days=1), + ) + mock_generate_windows.return_value = [window] + + # Create test entries + entries = [ + self._create_test_entry(self.reference_time + timedelta(hours=1)), + ] + mock_parse.return_value = entries + + # Mock mutex context manager + mock_mutex.return_value.__enter__.return_value = Mock() + mock_mutex.return_value.__exit__.return_value = None + + # Set up submission failure + self.mock_collector_instance.create_report_from_measurement.side_effect = ( + APIError("test error") + ) + + # Process should raise the error + with self.assertRaises(APIError): + self.processor.process_csv_file(self.csv_path) + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + def test_mutex_error_handling(self, mock_mutex, mock_parse): + """Test handling of mutex acquisition errors.""" + # Mock mutex acquisition failure + mock_mutex.return_value.__enter__.side_effect = Exception("Lock Error") + + with self.assertRaises(Exception): + self.processor.process_csv_file(self.csv_path) + + def test_state_persistence(self): + """Test that state is properly persisted between processing runs.""" + # Create initial state + self.processor.state.next_submission_after = self.reference_time + self.processor.state.save( + self.state_path, self.fileio_config.as_lockedfile_fileio_config() + ) + + # Create new processor instance + new_processor = MetricsProcessor( + self.process_config, + self.fileio_config, + self.window_policy, + ) + + # Verify state was loaded + self.assertEqual(new_processor.state.next_submission_after, self.reference_time) + + @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") + @patch("aggregatetunnelmetrics.lockedfile.Mutex") + @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") + def test_state_unchanged_on_submission_failure( + self, mock_generate_windows, mock_mutex, mock_parse + ): + """Test that state is not updated when measurement submission fails.""" + # Set up initial state + self.processor.state.next_submission_after = self.reference_time + self.processor.state.save( + self.state_path, self.fileio_config.as_lockedfile_fileio_config() + ) + + # Set up window + window = Window( + start=self.reference_time, + end=self.reference_time + timedelta(days=1), + delta=timedelta(days=1), + ) + mock_generate_windows.return_value = [window] + + # Create test entries + entries = [ + self._create_test_entry(self.reference_time + timedelta(hours=1)), + ] + mock_parse.return_value = entries + + # Mock mutex context manager + mock_mutex.return_value.__enter__.return_value = Mock() + mock_mutex.return_value.__exit__.return_value = None + + # Set up submission failure + self.mock_collector_instance.create_report_from_measurement.side_effect = ( + APIError("test error") + ) + + # Process should raise the error + with self.assertRaises(APIError): + self.processor.process_csv_file(self.csv_path) + + # Verify state hasn't changed + self.assertEqual( + self.processor.state.next_submission_after, self.reference_time + ) + + # Load state from disk and verify it hasn't changed + new_processor = MetricsProcessor( + self.process_config, + self.fileio_config, + self.window_policy, + ) + self.assertEqual(new_processor.state.next_submission_after, self.reference_time) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/aggregatetunnelmetrics/spec/__init__.py b/tests/aggregatetunnelmetrics/spec/__init__.py new file mode 100644 index 0000000..421f8fc --- /dev/null +++ b/tests/aggregatetunnelmetrics/spec/__init__.py @@ -0,0 +1,3 @@ +"""Tests for the aggregatetunnelmetrics.spec package.""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/spec/test_aggregator.py b/tests/aggregatetunnelmetrics/spec/test_aggregator.py new file mode 100644 index 0000000..eeada48 --- /dev/null +++ b/tests/aggregatetunnelmetrics/spec/test_aggregator.py @@ -0,0 +1,23 @@ +"""Tests for the aggregatetunnelmetrics.spec.aggregator module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from typing import Iterator + +from aggregatetunnelmetrics.spec import ( + aggregator, + fieldtesting, + oonicollector, +) + + +def test_logic_protocol(): + class TestLogic: + def aggregate( + self, + entries: Iterator[fieldtesting.Entry], + ) -> Iterator[oonicollector.Measurement]: + yield from [] + + # Test that our class implements the Logic protocol + assert isinstance(TestLogic(), aggregator.Logic) diff --git a/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py b/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py new file mode 100644 index 0000000..3195fd3 --- /dev/null +++ b/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py @@ -0,0 +1,42 @@ +"""Tests for the aggregatetunnelmetrics.spec.fieldtesting module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime + +from aggregatetunnelmetrics.spec import fieldtesting + + +def test_entry_tunnel_measurement(): + entry = fieldtesting.Entry( + filename="test.csv", + date=datetime.now(), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="test_user", + region="Test Region", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel", # successful tunnel measurement + throughput_download=10.0, + throughput_upload=5.0, + latency_download=100.0, + latency_upload=150.0, + retransmission_download=0.1, + retransmission_upload=0.2, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + protocol="openvpn", + ping_target_address="8.8.8.8", + ndt_target_hostname="ndt.server.com", + ndt_target_address="2.2.2.2", + ndt_target_port=3001, + ) + + assert entry.is_tunnel_measurement() is True + assert entry.is_tunnel_error_measurement() is False diff --git a/tests/aggregatetunnelmetrics/spec/test_filelocking.py b/tests/aggregatetunnelmetrics/spec/test_filelocking.py new file mode 100644 index 0000000..56a23a9 --- /dev/null +++ b/tests/aggregatetunnelmetrics/spec/test_filelocking.py @@ -0,0 +1,42 @@ +"""Tests for the aggregatetunnelmetrics.spec.filelocking module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from aggregatetunnelmetrics.spec import filelocking + + +def test_readwrite_config(): + config = filelocking.ReadWriteConfig() + assert config.num_retries == 10 + assert config.sleep_interval == 0.1 + + custom_config = filelocking.ReadWriteConfig(num_retries=5, sleep_interval=0.2) + assert custom_config.num_retries == 5 + assert custom_config.sleep_interval == 0.2 + + +class FakeMutex: + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + +def test_mutex_protocol(): + assert isinstance(FakeMutex(), filelocking.Mutex) + + +class FakeAPI: + def readfile(self, config: filelocking.ReadWriteConfig | None) -> str: + return "" + + def writefile(self, data: str, config: filelocking.ReadWriteConfig | None) -> None: + pass + + def mutex(self, filepath: str) -> filelocking.Mutex: + return FakeMutex() + + +def test_api_protocol(): + assert isinstance(FakeAPI(), filelocking.API) diff --git a/tests/aggregatetunnelmetrics/spec/test_metrics.py b/tests/aggregatetunnelmetrics/spec/test_metrics.py new file mode 100644 index 0000000..e0b26d3 --- /dev/null +++ b/tests/aggregatetunnelmetrics/spec/test_metrics.py @@ -0,0 +1,175 @@ +"""Tests for the aggregatetunnelmetrics.spec.metrics module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone + +from aggregatetunnelmetrics.spec import metrics + + +def test_global_scope(): + scope = metrics.GlobalScope(protocol="openvpn") + dct = {} + scope.update_dict(dct) + assert dct == {"scope": "global", "protocol": "openvpn"} + + +def test_distribution(): + dist = metrics.Distribution(p25=10.0, p50=20.0, p75=30.0, p99=40.0) + assert dist.as_dict() == {"25p": 10.0, "50p": 20.0, "75p": 30.0, "99p": 40.0} + + +def test_tunnel_ping_statement(): + dist = metrics.Distribution(p25=10.0, p50=20.0, p75=30.0, p99=40.0) + stmt = metrics.TunnelPingStatement( + target_address="8.8.8.8", + sample_size=100, + latency_avg=dist, + latency_min=dist, + latency_max=dist, + loss=dist, + ) + + assert stmt.as_dict() == { + "phase": "tunnel_ping", + "target_address": "8.8.8.8", + "sample_size": 100, + "type": "ping", + "latency_avg_ms": {"25p": 10.0, "50p": 20.0, "75p": 30.0, "99p": 40.0}, + "latency_min_ms": {"25p": 10.0, "50p": 20.0, "75p": 30.0, "99p": 40.0}, + "latency_max_ms": {"25p": 10.0, "50p": 20.0, "75p": 30.0, "99p": 40.0}, + "loss": {"25p": 10.0, "50p": 20.0, "75p": 30.0, "99p": 40.0}, + } + + +def test_endpoint_pool_scope(): + scope = metrics.EndpointPoolScope(protocol="openvpn", cc="US") + dct = {} + scope.update_dict(dct) + assert dct == {"scope": "endpoint_pool", "protocol": "openvpn", "cc": "US"} + + +def test_endpoint_scope(): + scope = metrics.EndpointScope( + protocol="openvpn", + cc="US", + asn="AS12345", + endpoint_hostname="test.com", + endpoint_address="1.1.1.1", + endpoint_port=443, + ) + dct = {} + scope.update_dict(dct) + assert dct == { + "scope": "endpoint", + "protocol": "openvpn", + "cc": "US", + "asn": "AS12345", + "endpoint_hostname": "test.com", + "endpoint_address": "1.1.1.1", + "endpoint_port": 443, + } + + +def test_network_error_statement(): + stmt = metrics.NetworkErrorStatement( + sample_size=100, failure_ratio=0.1, error="connection_failed" + ) + assert stmt.as_dict() == { + "phase": "creation", + "sample_size": 100, + "type": "network-error", + "failure_ratio": 0.1, + "error": "connection_failed", + } + + +def test_tunnel_ndt_statement(): + dist = metrics.Distribution(p25=10.0, p50=20.0, p75=30.0, p99=40.0) + stmt = metrics.TunnelNDTStatement( + direction="download", + target_hostname="test.com", + target_address="1.1.1.1", + target_port=443, + sample_size=100, + latency=dist, + speed=dist, + rexmit=dist, + ) + assert stmt.as_dict() == { + "phase": "tunnel_ndt_download", + "target_hostname": "test.com", + "target_address": "1.1.1.1", + "target_port": 443, + "sample_size": 100, + "type": "ndt_download", + "latency_ms": dist.as_dict(), + "speed_mbits": dist.as_dict(), + "rexmit": dist.as_dict(), + } + + +def test_tunnel_ndt_statement_none_values(): + stmt = metrics.TunnelNDTStatement( + direction="upload", + target_hostname=None, + target_address=None, + target_port=None, + sample_size=None, + latency=None, + speed=None, + rexmit=None, + ) + assert stmt.as_dict() == { + "phase": "tunnel_ndt_upload", + "target_hostname": None, + "target_address": None, + "target_port": None, + "sample_size": None, + "type": "ndt_download", + "latency_ms": None, + "speed_mbits": None, + "rexmit": None, + } + + +def test_time_window_as_dict(): + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = metrics.TimeWindow(start=start, end=end) + + assert window.as_dict() == {"start": "20240101T000000Z", "end": "20240102T000000Z"} + + +def test_measurement_test_keys_as_dict(): + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = metrics.TimeWindow(start=start, end=end) + + test_keys = metrics.MeasurementTestKeys( + provider="test_provider", + scope=metrics.GlobalScope(protocol="openvpn"), + time_window=window, + bodies=[], + ) + + assert test_keys.as_dict() == { + "provider": "test_provider", + "scope": "global", + "protocol": "openvpn", + "time_window": {"start": "20240101T000000Z", "end": "20240102T000000Z"}, + "bodies": [], + } + + +def test_format_datetime(): + dt = datetime(2024, 1, 1, 12, 34, 56, tzinfo=timezone.utc) + assert metrics.format_datetime(dt) == "20240101T123456Z" + + # Test with non-UTC timezone + from datetime import timedelta + + pacific = timezone(timedelta(hours=-8)) + dt = datetime(2024, 1, 1, 12, 34, 56, tzinfo=pacific) + # Should convert to UTC + assert metrics.format_datetime(dt) == "20240101T203456Z" diff --git a/tests/aggregatetunnelmetrics/spec/test_oonicollector.py b/tests/aggregatetunnelmetrics/spec/test_oonicollector.py new file mode 100644 index 0000000..63f9397 --- /dev/null +++ b/tests/aggregatetunnelmetrics/spec/test_oonicollector.py @@ -0,0 +1,153 @@ +"""Tests for the aggregatetunnelmetrics.spec.oonicollector module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + + +from datetime import datetime, timezone + +from aggregatetunnelmetrics.spec import oonicollector + + +class MockTestKeys: + def as_dict(self): + return {"test": "data"} + + +def test_measurement_as_dict(): + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + measurement = oonicollector.Measurement( + annotations={"source": "test"}, + data_format_version="0.2.0", + input="openvpn://testprovider/?param=value", + measurement_start_time=now, + probe_asn="AS12345", + probe_cc="US", + software_name="test_software", + software_version="1.0.0", + test_keys=MockTestKeys(), + test_name="aggregate_tunnel_metrics", + test_runtime=1.5, + test_start_time=now, + test_version="0.1.0", + ) + + result = measurement.as_dict() + assert result["data_format_version"] == "0.2.0" + assert result["probe_asn"] == "AS12345" + assert result["test_keys"] == {"test": "data"} + + +def test_measurement_with_report_id(): + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + measurement = oonicollector.Measurement( + annotations={}, + data_format_version="0.2.0", + input="test", + measurement_start_time=now, + probe_asn="AS12345", + probe_cc="US", + software_name="test", + software_version="1.0", + test_keys=MockTestKeys(), + test_name="test", + test_runtime=1.0, + test_start_time=now, + test_version="1.0", + ) + + new_measurement = measurement.with_report_id("test_report_id") + assert new_measurement.report_id == "test_report_id" + + +def test_measurement_with_optional_fields(): + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + measurement = oonicollector.Measurement( + annotations={"source": "test"}, + data_format_version="0.2.0", + input="openvpn://testprovider/?param=value", + measurement_start_time=now, + probe_asn="AS12345", + probe_cc="US", + software_name="test_software", + software_version="1.0.0", + test_keys=MockTestKeys(), + test_name="aggregate_tunnel_metrics", + test_runtime=1.5, + test_start_time=now, + test_version="0.1.0", + options=["option1", "option2"], + probe_network_name="Test Network", + resolver_asn="AS67890", + resolver_cc="UK", + resolver_ip="1.1.1.1", + resolver_network_name="Test Resolver", + test_helpers={"helper1": "value1"}, + ) + + result = measurement.as_dict() + # First verify all mandatory fields + assert result["data_format_version"] == "0.2.0" + assert result["probe_asn"] == "AS12345" + + # Then verify optional fields are present with correct values + assert result["options"] == ["option1", "option2"] + assert result["probe_network_name"] == "Test Network" + assert result["resolver_asn"] == "AS67890" + assert result["resolver_cc"] == "UK" + assert result["resolver_ip"] == "1.1.1.1" + assert result["resolver_network_name"] == "Test Resolver" + assert result["test_helpers"] == {"helper1": "value1"} + + +def test_measurement_as_open_report_request(): + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + measurement = oonicollector.Measurement( + annotations={}, + data_format_version="0.2.0", + input="test_input", + measurement_start_time=now, + probe_asn="AS12345", + probe_cc="US", + software_name="test_software", + software_version="1.0.0", + test_keys=MockTestKeys(), + test_name="test_name", + test_runtime=1.5, + test_start_time=now, + test_version="0.1.0", + ) + + report_request = measurement.as_open_report_request() + assert isinstance(report_request, oonicollector.OpenReportRequest) + assert report_request.probe_asn == "AS12345" + assert report_request.probe_cc == "US" + assert report_request.software_name == "test_software" + assert report_request.software_version == "1.0.0" + assert report_request.test_name == "test_name" + assert report_request.test_version == "0.1.0" + assert report_request.test_start_time == now + + +def test_open_report_request_as_dict(): + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + request = oonicollector.OpenReportRequest( + probe_asn="AS12345", + probe_cc="US", + software_name="test_software", + software_version="1.0.0", + test_name="test_name", + test_start_time=now, + test_version="0.1.0", + ) + + assert request.as_dict() == { + "data_format_version": "0.2.0", + "format": "json", + "probe_asn": "AS12345", + "probe_cc": "US", + "software_name": "test_software", + "software_version": "1.0.0", + "test_name": "test_name", + "test_start_time": "2024-01-01 00:00:00", + "test_version": "0.1.0", + } diff --git a/uv.lock b/uv.lock new file mode 100644 index 0000000..6686524 --- /dev/null +++ b/uv.lock @@ -0,0 +1,190 @@ +version = 1 +revision = 1 +requires-python = ">=3.12" + +[[package]] +name = "aggregatetunnelmetrics" +version = "0.1.0" +source = { editable = "." } + +[package.dev-dependencies] +dev = [ + { name = "black" }, + { name = "pytest" }, + { name = "pytest-cov" }, +] + +[package.metadata] + +[package.metadata.requires-dev] +dev = [ + { name = "black", specifier = ">=24.10.0" }, + { name = "pytest" }, + { name = "pytest-cov" }, +] + +[[package]] +name = "black" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "mypy-extensions" }, + { name = "packaging" }, + { name = "pathspec" }, + { name = "platformdirs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/49/26a7b0f3f35da4b5a65f081943b7bcd22d7002f5f0fb8098ec1ff21cb6ef/black-25.1.0.tar.gz", hash = "sha256:33496d5cd1222ad73391352b4ae8da15253c5de89b93a80b3e2c8d9a19ec2666", size = 649449 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/71/3fe4741df7adf015ad8dfa082dd36c94ca86bb21f25608eb247b4afb15b2/black-25.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4b60580e829091e6f9238c848ea6750efed72140b91b048770b64e74fe04908b", size = 1650988 }, + { url = "https://files.pythonhosted.org/packages/13/f3/89aac8a83d73937ccd39bbe8fc6ac8860c11cfa0af5b1c96d081facac844/black-25.1.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e2978f6df243b155ef5fa7e558a43037c3079093ed5d10fd84c43900f2d8ecc", size = 1453985 }, + { url = "https://files.pythonhosted.org/packages/6f/22/b99efca33f1f3a1d2552c714b1e1b5ae92efac6c43e790ad539a163d1754/black-25.1.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:3b48735872ec535027d979e8dcb20bf4f70b5ac75a8ea99f127c106a7d7aba9f", size = 1783816 }, + { url = "https://files.pythonhosted.org/packages/18/7e/a27c3ad3822b6f2e0e00d63d58ff6299a99a5b3aee69fa77cd4b0076b261/black-25.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:ea0213189960bda9cf99be5b8c8ce66bb054af5e9e861249cd23471bd7b0b3ba", size = 1440860 }, + { url = "https://files.pythonhosted.org/packages/98/87/0edf98916640efa5d0696e1abb0a8357b52e69e82322628f25bf14d263d1/black-25.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8f0b18a02996a836cc9c9c78e5babec10930862827b1b724ddfe98ccf2f2fe4f", size = 1650673 }, + { url = "https://files.pythonhosted.org/packages/52/e5/f7bf17207cf87fa6e9b676576749c6b6ed0d70f179a3d812c997870291c3/black-25.1.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:afebb7098bfbc70037a053b91ae8437c3857482d3a690fefc03e9ff7aa9a5fd3", size = 1453190 }, + { url = "https://files.pythonhosted.org/packages/e3/ee/adda3d46d4a9120772fae6de454c8495603c37c4c3b9c60f25b1ab6401fe/black-25.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:030b9759066a4ee5e5aca28c3c77f9c64789cdd4de8ac1df642c40b708be6171", size = 1782926 }, + { url = "https://files.pythonhosted.org/packages/cc/64/94eb5f45dcb997d2082f097a3944cfc7fe87e071907f677e80788a2d7b7a/black-25.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:a22f402b410566e2d1c950708c77ebf5ebd5d0d88a6a2e87c86d9fb48afa0d18", size = 1442613 }, + { url = "https://files.pythonhosted.org/packages/09/71/54e999902aed72baf26bca0d50781b01838251a462612966e9fc4891eadd/black-25.1.0-py3-none-any.whl", hash = "sha256:95e8176dae143ba9097f351d174fdaf0ccd29efb414b362ae3fd72bf0f710717", size = 207646 }, +] + +[[package]] +name = "click" +version = "8.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7e/d4/7ebdbd03970677812aac39c869717059dbb71a4cfc033ca6e5221787892c/click-8.1.8-py3-none-any.whl", hash = "sha256:63c132bbbed01578a06712a2d1f497bb62d9c1c0d329b7903a866228027263b2", size = 98188 }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, +] + +[[package]] +name = "coverage" +version = "7.6.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0c/d6/2b53ab3ee99f2262e6f0b8369a43f6d66658eab45510331c0b3d5c8c4272/coverage-7.6.12.tar.gz", hash = "sha256:48cfc4641d95d34766ad41d9573cc0f22a48aa88d22657a1fe01dca0dbae4de2", size = 805941 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/7f/4af2ed1d06ce6bee7eafc03b2ef748b14132b0bdae04388e451e4b2c529b/coverage-7.6.12-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b172f8e030e8ef247b3104902cc671e20df80163b60a203653150d2fc204d1ad", size = 208645 }, + { url = "https://files.pythonhosted.org/packages/dc/60/d19df912989117caa95123524d26fc973f56dc14aecdec5ccd7d0084e131/coverage-7.6.12-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:641dfe0ab73deb7069fb972d4d9725bf11c239c309ce694dd50b1473c0f641c3", size = 208898 }, + { url = "https://files.pythonhosted.org/packages/bd/10/fecabcf438ba676f706bf90186ccf6ff9f6158cc494286965c76e58742fa/coverage-7.6.12-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0e549f54ac5f301e8e04c569dfdb907f7be71b06b88b5063ce9d6953d2d58574", size = 242987 }, + { url = "https://files.pythonhosted.org/packages/4c/53/4e208440389e8ea936f5f2b0762dcd4cb03281a7722def8e2bf9dc9c3d68/coverage-7.6.12-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:959244a17184515f8c52dcb65fb662808767c0bd233c1d8a166e7cf74c9ea985", size = 239881 }, + { url = "https://files.pythonhosted.org/packages/c4/47/2ba744af8d2f0caa1f17e7746147e34dfc5f811fb65fc153153722d58835/coverage-7.6.12-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bda1c5f347550c359f841d6614fb8ca42ae5cb0b74d39f8a1e204815ebe25750", size = 242142 }, + { url = "https://files.pythonhosted.org/packages/e9/90/df726af8ee74d92ee7e3bf113bf101ea4315d71508952bd21abc3fae471e/coverage-7.6.12-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1ceeb90c3eda1f2d8c4c578c14167dbd8c674ecd7d38e45647543f19839dd6ea", size = 241437 }, + { url = "https://files.pythonhosted.org/packages/f6/af/995263fd04ae5f9cf12521150295bf03b6ba940d0aea97953bb4a6db3e2b/coverage-7.6.12-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:0f16f44025c06792e0fb09571ae454bcc7a3ec75eeb3c36b025eccf501b1a4c3", size = 239724 }, + { url = "https://files.pythonhosted.org/packages/1c/8e/5bb04f0318805e190984c6ce106b4c3968a9562a400180e549855d8211bd/coverage-7.6.12-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b076e625396e787448d27a411aefff867db2bffac8ed04e8f7056b07024eed5a", size = 241329 }, + { url = "https://files.pythonhosted.org/packages/9e/9d/fa04d9e6c3f6459f4e0b231925277cfc33d72dfab7fa19c312c03e59da99/coverage-7.6.12-cp312-cp312-win32.whl", hash = "sha256:00b2086892cf06c7c2d74983c9595dc511acca00665480b3ddff749ec4fb2a95", size = 211289 }, + { url = "https://files.pythonhosted.org/packages/53/40/53c7ffe3c0c3fff4d708bc99e65f3d78c129110d6629736faf2dbd60ad57/coverage-7.6.12-cp312-cp312-win_amd64.whl", hash = "sha256:7ae6eabf519bc7871ce117fb18bf14e0e343eeb96c377667e3e5dd12095e0288", size = 212079 }, + { url = "https://files.pythonhosted.org/packages/76/89/1adf3e634753c0de3dad2f02aac1e73dba58bc5a3a914ac94a25b2ef418f/coverage-7.6.12-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:488c27b3db0ebee97a830e6b5a3ea930c4a6e2c07f27a5e67e1b3532e76b9ef1", size = 208673 }, + { url = "https://files.pythonhosted.org/packages/ce/64/92a4e239d64d798535c5b45baac6b891c205a8a2e7c9cc8590ad386693dc/coverage-7.6.12-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5d1095bbee1851269f79fd8e0c9b5544e4c00c0c24965e66d8cba2eb5bb535fd", size = 208945 }, + { url = "https://files.pythonhosted.org/packages/b4/d0/4596a3ef3bca20a94539c9b1e10fd250225d1dec57ea78b0867a1cf9742e/coverage-7.6.12-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0533adc29adf6a69c1baa88c3d7dbcaadcffa21afbed3ca7a225a440e4744bf9", size = 242484 }, + { url = "https://files.pythonhosted.org/packages/1c/ef/6fd0d344695af6718a38d0861408af48a709327335486a7ad7e85936dc6e/coverage-7.6.12-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53c56358d470fa507a2b6e67a68fd002364d23c83741dbc4c2e0680d80ca227e", size = 239525 }, + { url = "https://files.pythonhosted.org/packages/0c/4b/373be2be7dd42f2bcd6964059fd8fa307d265a29d2b9bcf1d044bcc156ed/coverage-7.6.12-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:64cbb1a3027c79ca6310bf101014614f6e6e18c226474606cf725238cf5bc2d4", size = 241545 }, + { url = "https://files.pythonhosted.org/packages/a6/7d/0e83cc2673a7790650851ee92f72a343827ecaaea07960587c8f442b5cd3/coverage-7.6.12-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:79cac3390bfa9836bb795be377395f28410811c9066bc4eefd8015258a7578c6", size = 241179 }, + { url = "https://files.pythonhosted.org/packages/ff/8c/566ea92ce2bb7627b0900124e24a99f9244b6c8c92d09ff9f7633eb7c3c8/coverage-7.6.12-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:9b148068e881faa26d878ff63e79650e208e95cf1c22bd3f77c3ca7b1d9821a3", size = 239288 }, + { url = "https://files.pythonhosted.org/packages/7d/e4/869a138e50b622f796782d642c15fb5f25a5870c6d0059a663667a201638/coverage-7.6.12-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:8bec2ac5da793c2685ce5319ca9bcf4eee683b8a1679051f8e6ec04c4f2fd7dc", size = 241032 }, + { url = "https://files.pythonhosted.org/packages/ae/28/a52ff5d62a9f9e9fe9c4f17759b98632edd3a3489fce70154c7d66054dd3/coverage-7.6.12-cp313-cp313-win32.whl", hash = "sha256:200e10beb6ddd7c3ded322a4186313d5ca9e63e33d8fab4faa67ef46d3460af3", size = 211315 }, + { url = "https://files.pythonhosted.org/packages/bc/17/ab849b7429a639f9722fa5628364c28d675c7ff37ebc3268fe9840dda13c/coverage-7.6.12-cp313-cp313-win_amd64.whl", hash = "sha256:2b996819ced9f7dbb812c701485d58f261bef08f9b85304d41219b1496b591ef", size = 212099 }, + { url = "https://files.pythonhosted.org/packages/d2/1c/b9965bf23e171d98505eb5eb4fb4d05c44efd256f2e0f19ad1ba8c3f54b0/coverage-7.6.12-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:299cf973a7abff87a30609879c10df0b3bfc33d021e1adabc29138a48888841e", size = 209511 }, + { url = "https://files.pythonhosted.org/packages/57/b3/119c201d3b692d5e17784fee876a9a78e1b3051327de2709392962877ca8/coverage-7.6.12-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:4b467a8c56974bf06e543e69ad803c6865249d7a5ccf6980457ed2bc50312703", size = 209729 }, + { url = "https://files.pythonhosted.org/packages/52/4e/a7feb5a56b266304bc59f872ea07b728e14d5a64f1ad3a2cc01a3259c965/coverage-7.6.12-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2458f275944db8129f95d91aee32c828a408481ecde3b30af31d552c2ce284a0", size = 253988 }, + { url = "https://files.pythonhosted.org/packages/65/19/069fec4d6908d0dae98126aa7ad08ce5130a6decc8509da7740d36e8e8d2/coverage-7.6.12-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a9d8be07fb0832636a0f72b80d2a652fe665e80e720301fb22b191c3434d924", size = 249697 }, + { url = "https://files.pythonhosted.org/packages/1c/da/5b19f09ba39df7c55f77820736bf17bbe2416bbf5216a3100ac019e15839/coverage-7.6.12-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:14d47376a4f445e9743f6c83291e60adb1b127607a3618e3185bbc8091f0467b", size = 252033 }, + { url = "https://files.pythonhosted.org/packages/1e/89/4c2750df7f80a7872267f7c5fe497c69d45f688f7b3afe1297e52e33f791/coverage-7.6.12-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:b95574d06aa9d2bd6e5cc35a5bbe35696342c96760b69dc4287dbd5abd4ad51d", size = 251535 }, + { url = "https://files.pythonhosted.org/packages/78/3b/6d3ae3c1cc05f1b0460c51e6f6dcf567598cbd7c6121e5ad06643974703c/coverage-7.6.12-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:ecea0c38c9079570163d663c0433a9af4094a60aafdca491c6a3d248c7432827", size = 249192 }, + { url = "https://files.pythonhosted.org/packages/6e/8e/c14a79f535ce41af7d436bbad0d3d90c43d9e38ec409b4770c894031422e/coverage-7.6.12-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:2251fabcfee0a55a8578a9d29cecfee5f2de02f11530e7d5c5a05859aa85aee9", size = 250627 }, + { url = "https://files.pythonhosted.org/packages/cb/79/b7cee656cfb17a7f2c1b9c3cee03dd5d8000ca299ad4038ba64b61a9b044/coverage-7.6.12-cp313-cp313t-win32.whl", hash = "sha256:eb5507795caabd9b2ae3f1adc95f67b1104971c22c624bb354232d65c4fc90b3", size = 212033 }, + { url = "https://files.pythonhosted.org/packages/b6/c3/f7aaa3813f1fa9a4228175a7bd368199659d392897e184435a3b66408dd3/coverage-7.6.12-cp313-cp313t-win_amd64.whl", hash = "sha256:f60a297c3987c6c02ffb29effc70eadcbb412fe76947d394a1091a3615948e2f", size = 213240 }, + { url = "https://files.pythonhosted.org/packages/fb/b2/f655700e1024dec98b10ebaafd0cedbc25e40e4abe62a3c8e2ceef4f8f0a/coverage-7.6.12-py3-none-any.whl", hash = "sha256:eb8668cfbc279a536c633137deeb9435d2962caec279c3f8cf8b91fff6ff8953", size = 200552 }, +] + +[[package]] +name = "iniconfig" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d7/4b/cbd8e699e64a6f16ca3a8220661b5f83792b3017d0f79807cb8708d33913/iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3", size = 4646 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, +] + +[[package]] +name = "mypy-extensions" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/98/a4/1ab47638b92648243faf97a5aeb6ea83059cc3624972ab6b8d2316078d3f/mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782", size = 4433 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2a/e2/5d3f6ada4297caebe1a2add3b126fe800c96f56dbe5d1988a2cbe0b267aa/mypy_extensions-1.0.0-py3-none-any.whl", hash = "sha256:4392f6c0eb8a5668a69e23d168ffa70f0be9ccfd32b5cc2d26a34ae5b844552d", size = 4695 }, +] + +[[package]] +name = "packaging" +version = "24.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/63/68dbb6eb2de9cb10ee4c9c14a0148804425e13c4fb20d61cce69f53106da/packaging-24.2.tar.gz", hash = "sha256:c228a6dc5e932d346bc5739379109d49e8853dd8223571c7c5b55260edc0b97f", size = 163950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/ef/eb23f262cca3c0c4eb7ab1933c3b1f03d021f2c48f54763065b6f0e321be/packaging-24.2-py3-none-any.whl", hash = "sha256:09abb1bccd265c01f4a3aa3f7a7db064b36514d2cba19a2f694fe6150451a759", size = 65451 }, +] + +[[package]] +name = "pathspec" +version = "0.12.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ca/bc/f35b8446f4531a7cb215605d100cd88b7ac6f44ab3fc94870c120ab3adbf/pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712", size = 51043 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/20/ff623b09d963f88bfde16306a54e12ee5ea43e9b597108672ff3a408aad6/pathspec-0.12.1-py3-none-any.whl", hash = "sha256:a0d503e138a4c123b27490a4f7beda6a01c6f288df0e4a8b79c7eb0dc7b4cc08", size = 31191 }, +] + +[[package]] +name = "platformdirs" +version = "4.3.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/13/fc/128cc9cb8f03208bdbf93d3aa862e16d376844a14f9a0ce5cf4507372de4/platformdirs-4.3.6.tar.gz", hash = "sha256:357fb2acbc885b0419afd3ce3ed34564c13c9b95c89360cd9563f73aa5e2b907", size = 21302 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3c/a6/bc1012356d8ece4d66dd75c4b9fc6c1f6650ddd5991e421177d9f8f671be/platformdirs-4.3.6-py3-none-any.whl", hash = "sha256:73e575e1408ab8103900836b97580d5307456908a03e92031bab39e4554cc3fb", size = 18439 }, +] + +[[package]] +name = "pluggy" +version = "1.5.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/96/2d/02d4312c973c6050a18b314a5ad0b3210edb65a906f868e31c111dede4a6/pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1", size = 67955 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, +] + +[[package]] +name = "pytest" +version = "8.3.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/35/30e0d83068951d90a01852cb1cef56e5d8a09d20c7f511634cc2f7e0372a/pytest-8.3.4.tar.gz", hash = "sha256:965370d062bce11e73868e0335abac31b4d3de0e82f4007408d242b4f8610761", size = 1445919 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, +] + +[[package]] +name = "pytest-cov" +version = "6.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "coverage" }, + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/be/45/9b538de8cef30e17c7b45ef42f538a94889ed6a16f2387a6c89e73220651/pytest-cov-6.0.0.tar.gz", hash = "sha256:fde0b595ca248bb8e2d76f020b465f3b107c9632e6a1d1705f17834c89dcadc0", size = 66945 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/3b/48e79f2cd6a61dbbd4807b4ed46cb564b4fd50a76166b1c4ea5c1d9e2371/pytest_cov-6.0.0-py3-none-any.whl", hash = "sha256:eee6f1b9e61008bd34975a4d5bab25801eb31898b032dd55addc93e96fcaaa35", size = 22949 }, +] -- GitLab From 4ec8892a018f13a9a6b495f8b3d6fc0f3976c064 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 16:55:59 +0100 Subject: [PATCH 54/75] refactor(lockedfile): use protocols and migrate to pytest --- .gitignore | 2 +- aggregatetunnelmetrics/lockedfile/__init__.py | 9 +- aggregatetunnelmetrics/lockedfile/api.py | 40 ++ aggregatetunnelmetrics/lockedfile/common.py | 6 +- aggregatetunnelmetrics/lockedfile/fileio.py | 28 +- aggregatetunnelmetrics/pipeline/config.py | 4 +- aggregatetunnelmetrics/pipeline/state.py | 4 +- aggregatetunnelmetrics/spec/filelocking.py | 6 +- pyproject.toml | 3 + .../lockedfile/__init__.py | 3 + .../lockedfile/test_api.py | 125 ++++++ .../lockedfile/test_fileio.py | 382 +++++++++-------- .../lockedfile/test_mutex.py | 385 +++++++++--------- .../pipeline/test_config.py | 2 +- .../pipeline/test_state.py | 4 +- tests/aggregatetunnelmetrics/spec/__init__.py | 2 +- .../spec/test_aggregator.py | 2 +- .../spec/test_fieldtesting.py | 2 +- .../spec/test_filelocking.py | 10 +- .../spec/test_metrics.py | 2 +- .../spec/test_oonicollector.py | 2 +- 21 files changed, 613 insertions(+), 410 deletions(-) create mode 100644 aggregatetunnelmetrics/lockedfile/api.py create mode 100644 tests/aggregatetunnelmetrics/lockedfile/test_api.py diff --git a/.gitignore b/.gitignore index 9633c92..806f330 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -.coverage +.coverage* __pycache__ logs/* results/* diff --git a/aggregatetunnelmetrics/lockedfile/__init__.py b/aggregatetunnelmetrics/lockedfile/__init__.py index 3aed17b..7a95535 100644 --- a/aggregatetunnelmetrics/lockedfile/__init__.py +++ b/aggregatetunnelmetrics/lockedfile/__init__.py @@ -1,5 +1,6 @@ """ -File locking utilities for safe concurrent access. +File locking utilities for safe concurrent access implementing +the aggregatetunnelmetrics.spec.filelocking.API. This package provides utilities for reading and writing files while holding locks on them, ensuring safe concurrent access. @@ -11,12 +12,14 @@ Note: this package only works on Unix-like systems. # SPDX-License-Identifier: GPL-3.0-or-later +from .api import API from .common import FileLockError -from .fileio import FileIOConfig, read, write +from .fileio import ReadWriteConfig, read, write from .mutex import Mutex __all__ = [ - "FileIOConfig", + "API", + "ReadWriteConfig", "FileLockError", "Mutex", "read", diff --git a/aggregatetunnelmetrics/lockedfile/api.py b/aggregatetunnelmetrics/lockedfile/api.py new file mode 100644 index 0000000..a11bc05 --- /dev/null +++ b/aggregatetunnelmetrics/lockedfile/api.py @@ -0,0 +1,40 @@ +""" +Internal API for file locking. + +You should typically import `lockefile` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass + +from ..spec import filelocking + +from .fileio import read, write +from .mutex import Mutex + + +@dataclass +class API: + """Implements aggregatetunnelmetrics.spec.filelocking.API.""" + + def readfile( + self, + filepath: str, + config: filelocking.ReadWriteConfig | None, + ) -> str: + """Read file contents while holding a lock.""" + return read(filepath, config) + + def writefile( + self, + filepath: str, + data: str, + config: filelocking.ReadWriteConfig | None, + ) -> None: + """Write file contents while holding a lock.""" + write(filepath, data, config) + + def mutex(self, filepath: str) -> filelocking.Mutex: + """Create a mutex using the given lockfile.""" + return Mutex(filepath) diff --git a/aggregatetunnelmetrics/lockedfile/common.py b/aggregatetunnelmetrics/lockedfile/common.py index f74c32f..311262a 100644 --- a/aggregatetunnelmetrics/lockedfile/common.py +++ b/aggregatetunnelmetrics/lockedfile/common.py @@ -1,4 +1,8 @@ -"""Common code for file locking.""" +""" +Internal, common code for file locking. + +You should typically import `lockefile` directly instead. +""" # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/aggregatetunnelmetrics/lockedfile/fileio.py b/aggregatetunnelmetrics/lockedfile/fileio.py index 1e0cdc6..17bd780 100644 --- a/aggregatetunnelmetrics/lockedfile/fileio.py +++ b/aggregatetunnelmetrics/lockedfile/fileio.py @@ -6,7 +6,6 @@ You should typically import `lockefile` directly instead. # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass from io import TextIOWrapper from typing import Optional @@ -14,24 +13,15 @@ import fcntl import os import time +from ..spec import filelocking from .common import FileLockError -@dataclass(frozen=True) -class FileIOConfig: - """ - Configures attempting to acquire a file lock. - - Fields: - num_retries: Number of times to retry acquiring the lock - sleep_interval: Time to sleep between each attempt (in seconds) - """ - - num_retries: int = 10 - sleep_interval: float = 0.1 +ReadWriteConfig = filelocking.ReadWriteConfig +"""Type alias containing the file I/O configuration.""" -def read(filepath: str, config: Optional[FileIOConfig] = None) -> str: +def read(filepath: str, config: Optional[ReadWriteConfig] = None) -> str: """ Read entire file while holding a shared lock. @@ -53,7 +43,7 @@ def read(filepath: str, config: Optional[FileIOConfig] = None) -> str: _release(filep) -def write(filepath: str, data: str, config: Optional[FileIOConfig] = None) -> None: +def write(filepath: str, data: str, config: Optional[ReadWriteConfig] = None) -> None: """ Write entire file while holding an exclusive lock. @@ -80,11 +70,11 @@ def write(filepath: str, data: str, config: Optional[FileIOConfig] = None) -> No _release(filep) -def _acquire_shared(filep: TextIOWrapper, config: Optional[FileIOConfig]) -> bool: +def _acquire_shared(filep: TextIOWrapper, config: Optional[ReadWriteConfig]) -> bool: return _try_lock(filep, fcntl.LOCK_SH | fcntl.LOCK_NB, config) -def _acquire_exclusive(filep: TextIOWrapper, config: Optional[FileIOConfig]) -> bool: +def _acquire_exclusive(filep: TextIOWrapper, config: Optional[ReadWriteConfig]) -> bool: return _try_lock(filep, fcntl.LOCK_EX | fcntl.LOCK_NB, config) @@ -96,10 +86,10 @@ def _release(filep: TextIOWrapper) -> None: def _try_lock( - filep: TextIOWrapper, operation: int, config: Optional[FileIOConfig] + filep: TextIOWrapper, operation: int, config: Optional[ReadWriteConfig] ) -> bool: if not config: - config = FileIOConfig() + config = ReadWriteConfig() for _ in range(config.num_retries): try: fcntl.flock(filep.fileno(), operation) diff --git a/aggregatetunnelmetrics/pipeline/config.py b/aggregatetunnelmetrics/pipeline/config.py index 009aba2..6f3c488 100644 --- a/aggregatetunnelmetrics/pipeline/config.py +++ b/aggregatetunnelmetrics/pipeline/config.py @@ -62,8 +62,8 @@ class FileIOConfig: num_retries: int = 10 sleep_interval: float = 0.1 - def as_lockedfile_fileio_config(self) -> lockedfile.FileIOConfig: + def as_lockedfile_fileio_config(self) -> lockedfile.ReadWriteConfig: """Convert to a lockedfile.FileIOConfig.""" - return lockedfile.FileIOConfig( + return lockedfile.ReadWriteConfig( num_retries=self.num_retries, sleep_interval=self.sleep_interval ) diff --git a/aggregatetunnelmetrics/pipeline/state.py b/aggregatetunnelmetrics/pipeline/state.py index 82fee5a..76eeea5 100644 --- a/aggregatetunnelmetrics/pipeline/state.py +++ b/aggregatetunnelmetrics/pipeline/state.py @@ -25,7 +25,7 @@ class ProcessorState: next_submission_after: datetime | None = None @classmethod - def load(cls, path: str, config: lockedfile.FileIOConfig) -> ProcessorState: + def load(cls, path: str, config: lockedfile.ReadWriteConfig) -> ProcessorState: """Load state from file with proper locking.""" try: content = lockedfile.read(path, config) @@ -47,7 +47,7 @@ class ProcessorState: except ValueError as e: raise StateError(f"Invalid datetime in state: {e}") - def save(self, path: str, config: lockedfile.FileIOConfig) -> None: + def save(self, path: str, config: lockedfile.ReadWriteConfig) -> None: """Save state to file with proper locking.""" data = { "next_submission_after": ( diff --git a/aggregatetunnelmetrics/spec/filelocking.py b/aggregatetunnelmetrics/spec/filelocking.py index e7ee73c..6293f2a 100644 --- a/aggregatetunnelmetrics/spec/filelocking.py +++ b/aggregatetunnelmetrics/spec/filelocking.py @@ -56,8 +56,10 @@ class API(Protocol): mutex: create a mutex using a lockfile. """ - def readfile(self, config: ReadWriteConfig | None) -> str: ... + def readfile(self, filepath: str, config: ReadWriteConfig | None) -> str: ... - def writefile(self, data: str, config: ReadWriteConfig | None) -> None: ... + def writefile( + self, filepath: str, data: str, config: ReadWriteConfig | None + ) -> None: ... def mutex(self, filepath: str) -> Mutex: ... diff --git a/pyproject.toml b/pyproject.toml index c017058..b066e1e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,3 +12,6 @@ build-backend = "hatchling.build" [dependency-groups] dev = ["black>=24.10.0", "pytest", "pytest-cov"] + +[tool.pytest.ini_options] +markers = ["integration: marks tests as integration tests"] diff --git a/tests/aggregatetunnelmetrics/lockedfile/__init__.py b/tests/aggregatetunnelmetrics/lockedfile/__init__.py index e69de29..e904769 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/__init__.py +++ b/tests/aggregatetunnelmetrics/lockedfile/__init__.py @@ -0,0 +1,3 @@ +"""Tests for the lockedfile package.""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_api.py b/tests/aggregatetunnelmetrics/lockedfile/test_api.py new file mode 100644 index 0000000..3ef7110 --- /dev/null +++ b/tests/aggregatetunnelmetrics/lockedfile/test_api.py @@ -0,0 +1,125 @@ +"""Tests for the lockedfile.api module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from unittest.mock import patch + +import fcntl +import os +import tempfile + +import pytest + +from aggregatetunnelmetrics import lockedfile +from aggregatetunnelmetrics.spec import filelocking + + +@pytest.fixture +def temp_file(): + """Fixture providing a temporary file path.""" + fd, path = tempfile.mkstemp() + os.close(fd) + yield path + try: + os.unlink(path) + except FileNotFoundError: + pass + + +@pytest.fixture +def temp_dir(): + """Fixture providing a temporary directory.""" + path = tempfile.mkdtemp() + yield path + try: + # Clean up all files in the directory first + for filename in os.listdir(path): + filepath = os.path.join(path, filename) + try: + os.unlink(filepath) + except OSError: + pass + # Then remove the directory itself + os.rmdir(path) + except FileNotFoundError: + pass + + +@pytest.fixture +def api(): + """Fixture providing the API implementation.""" + return lockedfile.API() + + +def test_readfile_nonexistent(api, temp_dir): + """Test reading a nonexistent file fails.""" + nonexistent = os.path.join(temp_dir, "nonexistent") + with pytest.raises(FileNotFoundError): + api.readfile(nonexistent, None) + + +def test_writefile_then_readfile(api, temp_file): + """Test basic write then read functionality.""" + test_content = "test content" + api.writefile(temp_file, test_content, None) + result = api.readfile(temp_file, None) + assert result == test_content + + +def test_mutex_basic_operation(api, temp_file): + """Test basic mutex operation.""" + with api.mutex(temp_file): + assert os.path.exists(temp_file) + + +def test_mutex_prevents_concurrent_access(api, temp_file): + """Test mutex provides mutual exclusion.""" + with api.mutex(temp_file): + # Try to acquire same mutex - should fail + with pytest.raises(lockedfile.FileLockError): + with api.mutex(temp_file): + pass + + +def test_readfile_lock_failure(api, temp_file): + """Test readfile fails when unable to acquire lock.""" + with open(temp_file, "w") as f: + f.write("test") + + with patch("fcntl.flock", side_effect=BlockingIOError()): + with pytest.raises(lockedfile.FileLockError): + api.readfile(temp_file, filelocking.ReadWriteConfig(num_retries=1)) + + +def test_writefile_lock_failure(api, temp_file): + """Test writefile fails when unable to acquire lock.""" + with patch("fcntl.flock", side_effect=BlockingIOError()): + with pytest.raises(lockedfile.FileLockError): + api.writefile(temp_file, "test", filelocking.ReadWriteConfig(num_retries=1)) + + +def test_lock_release_error_suppressed(api, temp_file): + """Test that errors during lock release are suppressed.""" + + def flock_side_effect(fd, operation): + if operation == fcntl.LOCK_UN: # When trying to unlock + raise OSError("mock release error") + return None # Success for lock acquisition + + with patch("fcntl.flock", side_effect=flock_side_effect): + # Should complete without raising exception + api.readfile(temp_file, None) + + +def test_mutex_release_error_suppressed(api, temp_file): + """Test that errors during mutex release are suppressed.""" + + def flock_side_effect(fd, operation): + if operation == fcntl.LOCK_UN: + raise OSError("mock release error") + return None + + with patch("fcntl.flock", side_effect=flock_side_effect): + # Should complete without raising exception + with api.mutex(temp_file): + pass diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py index b0004ec..74a630f 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py +++ b/tests/aggregatetunnelmetrics/lockedfile/test_fileio.py @@ -1,186 +1,220 @@ -"""Tests for the lockedfile/fileio.py functionality.""" +"""Integration tests for file I/O operations through the API.""" # SPDX-License-Identifier: GPL-3.0-or-later -from unittest.mock import patch, mock_open +from dataclasses import dataclass +from multiprocessing.synchronize import Event +from multiprocessing import Queue, Value +from typing import Any -import fcntl import multiprocessing as mp -import os -import tempfile import time -import unittest -from aggregatetunnelmetrics.lockedfile import common, fileio - - -# Writer process function needs to be at module level for macOS compatibility -def _writer_process(temp_path, valid_contents, should_stop, config): - while not should_stop.is_set(): - for content in valid_contents: - fileio.write(temp_path, content, config=config) +import pytest + +from aggregatetunnelmetrics import lockedfile + + +@dataclass +class Context: + """Context for concurrent file access test.""" + + api: lockedfile.API + filepath: str + valid_contents: list[str] + results_queue: Queue + should_stop: Event + stats_contention: Any + stats_success: Any + processes: list[mp.Process] + start_time: float = 0.0 + + +def writer_process( + api: lockedfile.API, + filepath: str, + valid_contents: list[str], + should_stop: Event, + stats_contention: Any, + stats_success: Any, +) -> None: + """Writer process that records statistics about contention.""" + # Use minimal retries config to force contention + config = lockedfile.ReadWriteConfig(num_retries=1, sleep_interval=0.001) + try: + while not should_stop.is_set(): + for content in valid_contents: + try: + # No delay between attempts to increase contention + api.writefile(filepath, content, config) + with stats_success.get_lock(): + stats_success.value += 1 + except lockedfile.FileLockError: + with stats_contention.get_lock(): + stats_contention.value += 1 + except Exception as e: + print(f"Writer error: {e}") + + +def reader_process( + api: lockedfile.API, + filepath: str, + should_stop: Event, + results_queue: Queue, + stats_contention: Any, + stats_success: Any, +) -> None: + """Reader process that records statistics about contention.""" + # Use minimal retries config to force contention + config = lockedfile.ReadWriteConfig(num_retries=1, sleep_interval=0.001) + try: + while not should_stop.is_set(): + try: + # No delay between attempts to increase contention + content = api.readfile(filepath, config) + if content: + results_queue.put(content) + with stats_success.get_lock(): + stats_success.value += 1 + except lockedfile.FileLockError: + with stats_contention.get_lock(): + stats_contention.value += 1 + except Exception as e: + print(f"Reader error: {e}") + + +def setup_test_context(tmp_path) -> Context: + """Initialize test context with required resources.""" + ctx = Context( + api=lockedfile.API(), + filepath=str(tmp_path / "test.txt"), + valid_contents=["foo", "bar", "baz", "qux"], + results_queue=mp.Queue(), + should_stop=mp.Event(), + stats_contention=Value("i", 0), + stats_success=Value("i", 0), + processes=[], + ) + + # Initialize file + with open(ctx.filepath, "w") as f: + f.write(ctx.valid_contents[0]) + + return ctx + + +def create_processes(ctx: Context) -> None: + """Create reader and writer processes.""" + # Create writers + for _ in range(8): + p = mp.Process( + target=writer_process, + args=( + ctx.api, + ctx.filepath, + ctx.valid_contents, + ctx.should_stop, + ctx.stats_contention, + ctx.stats_success, + ), + ) + ctx.processes.append(p) + + # Create readers + for _ in range(8): + p = mp.Process( + target=reader_process, + args=( + ctx.api, + ctx.filepath, + ctx.should_stop, + ctx.results_queue, + ctx.stats_contention, + ctx.stats_success, + ), + ) + ctx.processes.append(p) -# Reader process function needs to be at module level for macOS compatibility -def _reader_process(temp_path, should_stop, results_queue, config): - while not should_stop.is_set(): - content = fileio.read(temp_path, config=config) - if content: - results_queue.put(content) - - -class TestFileIOUnit(unittest.TestCase): - """Unit tests for the lockedfile/fileio.py functionality.""" - - def test_read_success(self): - """Test successful file read with lock acquisition.""" - mock_content = "test content" - m = mock_open(read_data=mock_content) - - with patch("builtins.open", m), patch("fcntl.flock") as mock_flock: - result = fileio.read("test.txt") - self.assertEqual(result, mock_content) - m().read.assert_called_once() - self.assertEqual(mock_flock.call_count, 2) # acquire and release - - def test_read_lock_failure(self): - """Test read fails when unable to acquire lock.""" - m = mock_open() - - with ( - patch("builtins.open", m), - patch("fcntl.flock", side_effect=BlockingIOError()), - patch("time.sleep"), - ): - - with self.assertRaises(common.FileLockError): - fileio.read("test.txt") - - def test_write_success(self): - """Test successful file write with lock acquisition.""" - test_content = "test content" - m = mock_open() - - with ( - patch("builtins.open", m), - patch("fcntl.flock") as mock_flock, - patch("os.fsync") as mock_fsync, - ): - - fileio.write("test.txt", test_content) - - m().write.assert_called_once_with(test_content) - m().flush.assert_called_once() - mock_fsync.assert_called_once() - self.assertEqual(mock_flock.call_count, 2) # acquire and release - - def test_write_lock_failure(self): - """Test write fails when unable to acquire lock.""" - m = mock_open() - - with ( - patch("builtins.open", m), - patch("fcntl.flock", side_effect=BlockingIOError()), - patch("time.sleep"), - ): - - with self.assertRaises(common.FileLockError): - fileio.write("test.txt", "content") - - def test_release_error_is_suppressed(self): - """Test that errors during lock release are suppressed.""" - mock_content = "test content" - m = mock_open(read_data=mock_content) - - def flock_side_effect(fd, operation): - if operation == fcntl.LOCK_UN: # When trying to unlock - raise OSError("mock release error") - return None # Success for lock acquisition - - with ( - patch("builtins.open", m), - patch("fcntl.flock", side_effect=flock_side_effect), - ): - # Should complete without raising exception - result = fileio.read("test.txt") - self.assertEqual(result, mock_content) - - -class TestFileIOIntegration(unittest.TestCase): - """Integration tests for the lockedfile/fileio.py functionality.""" - - def setUp(self): - self.temp_fd, self.temp_path = tempfile.mkstemp() - - def tearDown(self): - os.close(self.temp_fd) - os.unlink(self.temp_path) - - def test_concurrent_access(self): - """ - The purpose of this test is to spawn many readers and - writers using background processes, to ensure that, - regardless of concurrent access attempts, we are always - able to write to the file, and read from the file, - consistent strings. With enough repetitions, if there - are actual race conditions, the file content will - possibly eventually be corrupted and we will notice. - """ - - # TODO(bassosimone): ensure we have some cases of contention - # by collecting statistics about retries. - - valid_contents = ["foo", "bar", "baz", "qux"] - results_queue = mp.Queue() - should_stop = mp.Event() - config = fileio.FileIOConfig(num_retries=30, sleep_interval=0.1) - - writers = [ - mp.Process( - target=_writer_process, - args=(self.temp_path, valid_contents, should_stop, config), - ) - for _ in range(4) - ] - - readers = [ - mp.Process( - target=_reader_process, - args=(self.temp_path, should_stop, results_queue, config), - ) - for _ in range(8) - ] - +def run_processes(ctx: Context) -> None: + """Start and manage process execution.""" + try: # Start all processes - for p in readers + writers: + for p in ctx.processes: p.start() - # Allow the processes to run for a while - time.sleep(1) - - # Interrupt the processes - should_stop.set() - - # Wait for processes to terminate - for p in readers + writers: - p.join() - - # Collect the results from the queue - observed_contents = set() - while not results_queue.empty(): - value = results_queue.get() - observed_contents.add(value) - - # Ensure we never read garbled data - self.assertTrue( - len(observed_contents) > 0, - "No data was read", - ) - self.assertTrue( - observed_contents.issubset(valid_contents), - f"Observed contents: {observed_contents}", - ) - - -if __name__ == "__main__": - unittest.main() + ctx.start_time = time.time() + time.sleep(1.0) + ctx.should_stop.set() + + # Wait for processes with timeout + wait_for_processes(ctx.processes) + + finally: + cleanup_processes(ctx.processes) + + +def wait_for_processes(processes: list[mp.Process], timeout: float = 2.0) -> None: + """Wait for processes to complete with timeout.""" + start_time = time.time() + while time.time() - start_time < timeout: + if not any(p.is_alive() for p in processes): + break + time.sleep(0.1) + + +def cleanup_processes(processes: list[mp.Process]) -> None: + """Ensure all processes are terminated and cleaned up.""" + for p in processes: + try: + if p.is_alive(): + p.terminate() + p.join(0.1) + except Exception: + pass + + +def collect_and_verify_results(ctx: Context) -> None: + """Collect results and verify test conditions.""" + observed_contents = set() + while not ctx.results_queue.empty(): + try: + observed_contents.add(ctx.results_queue.get_nowait()) + except Exception: + break + + # Print statistics + print(f"\nTest duration: {time.time() - ctx.start_time:.2f} seconds") + print(f"Contention events: {ctx.stats_contention.value}") + print(f"Successful operations: {ctx.stats_success.value}") + + total_ops = ctx.stats_contention.value + ctx.stats_success.value + if total_ops > 0: + contention_ratio = ctx.stats_contention.value / total_ops + print(f"Contention ratio: {contention_ratio:.2%}") + + # Verify results + assert ( + ctx.stats_contention.value > 1000 + ), f"Expected significant contention (got {ctx.stats_contention.value})" + assert ( + ctx.stats_success.value > 1000 + ), f"Expected significant successes (got {ctx.stats_success.value})" + + contention_ratio = ctx.stats_contention.value / total_ops + assert ( + 0.3 <= contention_ratio <= 0.7 + ), f"Expected balanced contention ratio (got {contention_ratio:.2%})" + + assert observed_contents.issubset( + ctx.valid_contents + ), f"Data corruption detected: {observed_contents}" + + +@pytest.mark.integration +def test_concurrent_file_access(tmp_path): + """Integration test verifying concurrent file access with contention.""" + ctx = setup_test_context(tmp_path) + create_processes(ctx) + run_processes(ctx) + collect_and_verify_results(ctx) diff --git a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py index e6960ec..e6213e4 100644 --- a/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py +++ b/tests/aggregatetunnelmetrics/lockedfile/test_mutex.py @@ -1,210 +1,205 @@ -"""Tests for the lockedfile/mutex.py functionality.""" +"""Integration tests for mutex operations through the API.""" # SPDX-License-Identifier: GPL-3.0-or-later -from unittest.mock import patch, mock_open -import fcntl +from dataclasses import dataclass +from multiprocessing.synchronize import Event +from multiprocessing import Queue, Value +from typing import Any + import multiprocessing as mp import os -import tempfile import time -import unittest - -from aggregatetunnelmetrics.lockedfile import Mutex, FileLockError - -def writer_process(lock_path, temp_path, valid_contents, should_stop, results_queue): +import pytest + +from aggregatetunnelmetrics import lockedfile + + +@dataclass +class Context: + """Context for concurrent mutex access test.""" + + api: lockedfile.API + lock_path: str + data_path: str + valid_contents: list[str] + results_queue: Queue + should_stop: Event + stats_contention: Any + stats_success: Any + processes: list[mp.Process] + start_time: float = 0.0 + + +def worker_process( + api: lockedfile.API, + lock_path: str, + data_path: str, + valid_contents: list[str], + should_stop: Event, + results_queue: Queue, + stats_contention: Any, + stats_success: Any, +) -> None: + """Worker process that tries to access shared resource with mutex.""" try: while not should_stop.is_set(): for content in valid_contents: - while not should_stop.is_set(): - try: - with Mutex(lock_path): - with open(temp_path, "w") as f: - f.write(content) - f.flush() - os.fsync(f.fileno()) - break # Success! Move to next content - except FileLockError: - # Lock busy, try again - time.sleep(0.1) - except Exception as exc: - results_queue.put(exc) - - -def reader_process(lock_path, temp_path, should_stop, results_queue): - try: - while not should_stop.is_set(): - try: - with Mutex(lock_path): - with open(temp_path) as f: - content = f.read() - if content: - results_queue.put(content) - break # Success! Continue outer loop - except FileLockError: - # Lock busy, try again - time.sleep(0.1) - except Exception as exc: - results_queue.put(exc) - - -class TestMutexUnit(unittest.TestCase): - """Unit tests for the lockedfile/mutex.py functionality.""" - - def test_mutex_acquire_success(self): - """Test successful mutex lock acquisition.""" - m = mock_open() - - with patch("builtins.open", m), patch("fcntl.flock") as mock_flock: - with Mutex("test.lock"): - mock_flock.assert_called_once_with( - m().fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB - ) - - def test_mutex_release(self): - """Test mutex properly releases lock.""" - m = mock_open() - - with patch("builtins.open", m), patch("fcntl.flock") as mock_flock: - with Mutex("test.lock"): - pass - - # Second call should be unlock - mock_flock.assert_called_with(m().fileno(), fcntl.LOCK_UN) - self.assertEqual(mock_flock.call_count, 2) - m().close.assert_called_once() - - def test_mutex_acquire_failure_io_error(self): - """Test mutex fails to acquire with IOError.""" - m = mock_open() - - with ( - patch("builtins.open", m), - patch("fcntl.flock", side_effect=IOError("mock error")), - ): - - with self.assertRaises(FileLockError) as cm: - with Mutex("test.lock"): - pass - self.assertIn("mock error", str(cm.exception)) - m().close.assert_called_once() - - def test_mutex_acquire_failure_blocking(self): - """Test mutex fails to acquire when already locked.""" - m = mock_open() - - with ( - patch("builtins.open", m), - patch("fcntl.flock", side_effect=BlockingIOError()), - ): - - with self.assertRaises(FileLockError) as cm: - with Mutex("test.lock"): - pass - self.assertIn("cannot acquire lock", str(cm.exception)) - m().close.assert_called_once() - - def test_mutex_release_error_suppressed(self): - """Test that errors during mutex release are suppressed.""" - m = mock_open() - - def flock_side_effect(fd, operation): - if operation == fcntl.LOCK_UN: - raise OSError("mock release error") - return None - - with ( - patch("builtins.open", m), - patch("fcntl.flock", side_effect=flock_side_effect), - ): - # Should complete without raising exception - with Mutex("test.lock"): - pass - m().close.assert_called_once() - - -class TestMutexIntegration(unittest.TestCase): - """Integration tests for the lockedfile/mutex.py functionality.""" - - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.lock_path = os.path.join(self.temp_dir, "test.lock") - self.temp_path = os.path.join(self.temp_dir, "test.txt") - - def tearDown(self): - try: - os.unlink(self.lock_path) - os.unlink(self.temp_path) - os.rmdir(self.temp_dir) - except FileNotFoundError: - pass + try: + # Use minimal delay to increase contention + with api.mutex(lock_path): + # Read current content + with open(data_path) as f: + current = f.read() + results_queue.put(current) + + # Write new content + with open(data_path, "w") as f: + f.write(content) + f.flush() + os.fsync(f.fileno()) + + with stats_success.get_lock(): + stats_success.value += 1 + + except lockedfile.FileLockError: + with stats_contention.get_lock(): + stats_contention.value += 1 + # Minimal sleep to maximize contention + time.sleep(0.001) + except Exception as e: + print(f"Worker error: {e}") + + +def setup_test_context(tmp_path) -> Context: + """Initialize test context with required resources.""" + ctx = Context( + api=lockedfile.API(), + lock_path=str(tmp_path / "test.lock"), + data_path=str(tmp_path / "test.txt"), + valid_contents=["alpha", "beta", "gamma", "delta"], + results_queue=mp.Queue(), + should_stop=mp.Event(), + stats_contention=Value("i", 0), + stats_success=Value("i", 0), + processes=[], + ) + + # Initialize data file + with open(ctx.data_path, "w") as f: + f.write(ctx.valid_contents[0]) + + return ctx + + +def create_processes(ctx: Context) -> None: + """Create worker processes.""" + # Create multiple workers to ensure contention + for _ in range(8): + p = mp.Process( + target=worker_process, + args=( + ctx.api, + ctx.lock_path, + ctx.data_path, + ctx.valid_contents, + ctx.should_stop, + ctx.results_queue, + ctx.stats_contention, + ctx.stats_success, + ), + ) + ctx.processes.append(p) + - def test_concurrent_access(self): - """ - Test that mutex provides mutual exclusion with concurrent access. - By writing only from a fixed set of valid strings, any corruption - would result in invalid content that wouldn't match our set. - """ - valid_contents = ["alpha", "beta", "gamma", "delta"] - results_queue = mp.Queue() - should_stop = mp.Event() - - # Create and initialize the file - with open(self.temp_path, "w") as f: - f.write(valid_contents[0]) - - # Start processes - writers = [ - mp.Process( - target=writer_process, - args=( - self.lock_path, - self.temp_path, - valid_contents, - should_stop, - results_queue, - ), - ) - for _ in range(4) - ] - readers = [ - mp.Process( - target=reader_process, - args=(self.lock_path, self.temp_path, should_stop, results_queue), - ) - for _ in range(8) - ] - - for p in readers + writers: +def run_processes(ctx: Context) -> None: + """Start and manage process execution.""" + try: + # Start all processes + for p in ctx.processes: p.start() - # Let them run for a bit - time.sleep(1) - should_stop.set() - - # Wait for all processes - for p in readers + writers: - p.join() - - # Check for any errors - observed_contents = set() - while not results_queue.empty(): - item = results_queue.get() - if isinstance(item, Exception): - raise item - observed_contents.add(item) - - # Verify we only saw valid contents - self.assertTrue( - len(observed_contents) > 0, - "No data was read", - ) - self.assertTrue( - observed_contents.issubset(valid_contents), - f"Observed corrupted contents: {observed_contents}", - ) + ctx.start_time = time.time() + time.sleep(1.0) + ctx.should_stop.set() + + # Wait for processes with timeout + wait_for_processes(ctx.processes) + + finally: + cleanup_processes(ctx.processes) + + +def wait_for_processes(processes: list[mp.Process], timeout: float = 2.0) -> None: + """Wait for processes to complete with timeout.""" + start_time = time.time() + while time.time() - start_time < timeout: + if not any(p.is_alive() for p in processes): + break + time.sleep(0.1) -if __name__ == "__main__": - unittest.main() +def cleanup_processes(processes: list[mp.Process]) -> None: + """Ensure all processes are terminated and cleaned up.""" + for p in processes: + try: + if p.is_alive(): + p.terminate() + p.join(0.1) + except Exception: + pass + + +def collect_and_verify_results(ctx: Context) -> None: + """Collect results and verify test conditions.""" + observed_contents = set() + while not ctx.results_queue.empty(): + try: + observed_contents.add(ctx.results_queue.get_nowait()) + except Exception: + break + + # Print statistics + print(f"\nTest duration: {time.time() - ctx.start_time:.2f} seconds") + print(f"Contention events: {ctx.stats_contention.value}") + print(f"Successful operations: {ctx.stats_success.value}") + + total_ops = ctx.stats_contention.value + ctx.stats_success.value + if total_ops > 0: + contention_ratio = ctx.stats_contention.value / total_ops + print(f"Contention ratio: {contention_ratio:.2%}") + + # Verify results + assert ( + ctx.stats_contention.value > 1000 + ), f"Expected significant contention (got {ctx.stats_contention.value})" + assert ( + ctx.stats_success.value > 1000 + ), f"Expected significant successes (got {ctx.stats_success.value})" + + contention_ratio = ctx.stats_contention.value / total_ops + assert ( + 0.3 <= contention_ratio <= 0.7 + ), f"Expected balanced contention ratio (got {contention_ratio:.2%})" + + assert observed_contents.issubset( + ctx.valid_contents + ), f"Data corruption detected: {observed_contents}" + + +@pytest.mark.integration +def test_concurrent_mutex_access(tmp_path): + """ + Integration test verifying mutex provides proper exclusion with contention. + + This test ensures that: + 1. We observe significant contention between processes + 2. Despite contention, data remains consistent + 3. Mutex operations eventually succeed + 4. We achieve a balanced contention ratio + """ + ctx = setup_test_context(tmp_path) + create_processes(ctx) + run_processes(ctx) + collect_and_verify_results(ctx) diff --git a/tests/aggregatetunnelmetrics/pipeline/test_config.py b/tests/aggregatetunnelmetrics/pipeline/test_config.py index 2caf3e8..c9af26a 100644 --- a/tests/aggregatetunnelmetrics/pipeline/test_config.py +++ b/tests/aggregatetunnelmetrics/pipeline/test_config.py @@ -5,7 +5,7 @@ import unittest from aggregatetunnelmetrics.pipeline.config import ProcessConfig, FileIOConfig -from aggregatetunnelmetrics.lockedfile import FileIOConfig as LockedFileIOConfig +from aggregatetunnelmetrics.lockedfile import ReadWriteConfig as LockedFileIOConfig class TestProcessConfig(unittest.TestCase): diff --git a/tests/aggregatetunnelmetrics/pipeline/test_state.py b/tests/aggregatetunnelmetrics/pipeline/test_state.py index 5541eb8..1344bb9 100644 --- a/tests/aggregatetunnelmetrics/pipeline/test_state.py +++ b/tests/aggregatetunnelmetrics/pipeline/test_state.py @@ -11,7 +11,7 @@ from unittest.mock import patch from aggregatetunnelmetrics.pipeline.state import ProcessorState from aggregatetunnelmetrics.pipeline.errors import StateError -from aggregatetunnelmetrics.lockedfile import FileIOConfig +from aggregatetunnelmetrics.lockedfile import ReadWriteConfig class TestProcessorState(unittest.TestCase): @@ -21,7 +21,7 @@ class TestProcessorState(unittest.TestCase): """Create a temporary file for testing.""" self.temp_dir = tempfile.mkdtemp() self.state_file = os.path.join(self.temp_dir, "state.json") - self.config = FileIOConfig(num_retries=1, sleep_interval=0.1) + self.config = ReadWriteConfig(num_retries=1, sleep_interval=0.1) def tearDown(self): """Clean up temporary files.""" diff --git a/tests/aggregatetunnelmetrics/spec/__init__.py b/tests/aggregatetunnelmetrics/spec/__init__.py index 421f8fc..c0e5698 100644 --- a/tests/aggregatetunnelmetrics/spec/__init__.py +++ b/tests/aggregatetunnelmetrics/spec/__init__.py @@ -1,3 +1,3 @@ -"""Tests for the aggregatetunnelmetrics.spec package.""" +"""Tests for the spec package.""" # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/spec/test_aggregator.py b/tests/aggregatetunnelmetrics/spec/test_aggregator.py index eeada48..7b1246d 100644 --- a/tests/aggregatetunnelmetrics/spec/test_aggregator.py +++ b/tests/aggregatetunnelmetrics/spec/test_aggregator.py @@ -1,4 +1,4 @@ -"""Tests for the aggregatetunnelmetrics.spec.aggregator module.""" +"""Tests for the spec.aggregator module.""" # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py b/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py index 3195fd3..b426ffb 100644 --- a/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py +++ b/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py @@ -1,4 +1,4 @@ -"""Tests for the aggregatetunnelmetrics.spec.fieldtesting module.""" +"""Tests for the spec.fieldtesting module.""" # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/spec/test_filelocking.py b/tests/aggregatetunnelmetrics/spec/test_filelocking.py index 56a23a9..4ddeded 100644 --- a/tests/aggregatetunnelmetrics/spec/test_filelocking.py +++ b/tests/aggregatetunnelmetrics/spec/test_filelocking.py @@ -1,4 +1,4 @@ -"""Tests for the aggregatetunnelmetrics.spec.filelocking module.""" +"""Tests for the spec.filelocking module.""" # SPDX-License-Identifier: GPL-3.0-or-later @@ -28,10 +28,14 @@ def test_mutex_protocol(): class FakeAPI: - def readfile(self, config: filelocking.ReadWriteConfig | None) -> str: + def readfile( + self, filepath: str, config: filelocking.ReadWriteConfig | None + ) -> str: return "" - def writefile(self, data: str, config: filelocking.ReadWriteConfig | None) -> None: + def writefile( + self, filepath: str, data: str, config: filelocking.ReadWriteConfig | None + ) -> None: pass def mutex(self, filepath: str) -> filelocking.Mutex: diff --git a/tests/aggregatetunnelmetrics/spec/test_metrics.py b/tests/aggregatetunnelmetrics/spec/test_metrics.py index e0b26d3..bf16b29 100644 --- a/tests/aggregatetunnelmetrics/spec/test_metrics.py +++ b/tests/aggregatetunnelmetrics/spec/test_metrics.py @@ -1,4 +1,4 @@ -"""Tests for the aggregatetunnelmetrics.spec.metrics module.""" +"""Tests for the spec.metrics module.""" # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/spec/test_oonicollector.py b/tests/aggregatetunnelmetrics/spec/test_oonicollector.py index 63f9397..31864d4 100644 --- a/tests/aggregatetunnelmetrics/spec/test_oonicollector.py +++ b/tests/aggregatetunnelmetrics/spec/test_oonicollector.py @@ -1,4 +1,4 @@ -"""Tests for the aggregatetunnelmetrics.spec.oonicollector module.""" +"""Tests for the spec.oonicollector module.""" # SPDX-License-Identifier: GPL-3.0-or-later -- GitLab From 6de4f4e297174b2722beaa43506d8b4070f31970 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 17:08:25 +0100 Subject: [PATCH 55/75] refactor(fieldtestingcsv): use protocol and pytest --- .../fieldtestingcsv/__init__.py | 4 +- .../fieldtestingcsv/model.py | 52 +-- .../fieldtestingcsv/parser.py | 16 +- .../fieldtestingcsv/__init__.py | 3 + .../fieldtestingcsv/test_parser.py | 435 ++++++++++-------- 5 files changed, 260 insertions(+), 250 deletions(-) diff --git a/aggregatetunnelmetrics/fieldtestingcsv/__init__.py b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py index efffa6b..c822140 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/__init__.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py @@ -10,6 +10,6 @@ See https://0xacab.org/leap/solitech-compose-client/-/blob/main/images/obfsvpn-o # SPDX-License-Identifier: GPL-3.0-or-later from .model import Entry -from .parser import parse_file +from .parser import Streamer, parse_file -__all__ = ["Entry", "parse_file"] +__all__ = ["Streamer", "Entry", "parse_file"] diff --git a/aggregatetunnelmetrics/fieldtestingcsv/model.py b/aggregatetunnelmetrics/fieldtestingcsv/model.py index 85d1f25..65db309 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/model.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/model.py @@ -6,54 +6,8 @@ We expect you to import `fieldtestingcsv` directly, not this module. # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass -from datetime import datetime +from ..spec import fieldtesting -@dataclass(frozen=True) -class Entry: - """ - Models a single field-testing entry read from the CSV datastore. - - The order of the fields in this dataclass it the same - of the fields within the CSV file. - """ - - # Fields present in the CSV file format as of 2024-12-06 - filename: str - date: datetime - asn: str - isp: str - est_city: str - user: str - region: str - server_fqdn: str - server_ip: str - mobile: bool - tunnel: str # 'baseline', 'tunnel', 'ERROR/baseline', 'ERROR/tunnel' - throughput_download: float - throughput_upload: float - latency_download: float - latency_upload: float - retransmission_download: float - retransmission_upload: float - ping_packets_loss: float - ping_roundtrip_min: float - ping_roundtrip_avg: float - ping_roundtrip_max: float - err_message: str - protocol: str - - # TODO(bassosimone): do we need to specialize on the ping target address - # or shall we just consider it to be a constant? - - def is_tunnel_measurement(self) -> bool: - """ - Return whether this is a tunnel measurement, which includes both - successful and failed tunnel measurements. - """ - return self.tunnel in ("tunnel", "ERROR/tunnel") - - def is_tunnel_error_measurement(self) -> bool: - """Return whether this is a failed tunnel measurement""" - return self.tunnel == "ERROR/tunnel" +Entry = fieldtesting.Entry +"""Type alias for the field-testing CSV entry type.""" diff --git a/aggregatetunnelmetrics/fieldtestingcsv/parser.py b/aggregatetunnelmetrics/fieldtestingcsv/parser.py index 3e70dbd..e2eec85 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/parser.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/parser.py @@ -86,7 +86,13 @@ def parse_single_row(row: Dict[str, str]) -> Entry: ping_roundtrip_avg=float(row["ping_roundtrip_avg"]), ping_roundtrip_max=float(row["ping_roundtrip_max"]), err_message=str(row["err_message"]).strip(), - protocol=str(row["PT"]), # rename from "PT" to "protocol" + # Rename from "PT" to "protocol" + protocol=str(row["PT"]), + # Add the new fields from the spec with None values + ping_target_address=None, + ndt_target_hostname=None, + ndt_target_address=None, + ndt_target_port=None, ) @@ -119,3 +125,11 @@ def stream_file(filename: str) -> Iterator[Entry]: except (ValueError, KeyError) as exc: logging.warning(f"cannot import row: {exc}") continue + + +class Streamer: + """Implements aggregatetunnelmetrics.fieldtesting.Streamer""" + + def stream(self, filepath: str) -> Iterator[Entry]: + """Implements aggregatetunnelmetrics.fieldtesting.Streamer""" + return stream_file(filepath) diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/__init__.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/__init__.py index e69de29..f6a61d4 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/__init__.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/__init__.py @@ -0,0 +1,3 @@ +"""Tests for the fieldtestingcsv package.""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index bd002c4..7e8732d 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -1,212 +1,208 @@ -"""Tests for the field-testing CSV parser functionality.""" +"""Tests for the fieldtestingcsv.parser module.""" # SPDX-License-Identifier: GPL-3.0-or-later -import unittest from datetime import datetime, timezone import tempfile import os import logging +import pytest + +from aggregatetunnelmetrics import fieldtestingcsv +from aggregatetunnelmetrics.spec import fieldtesting + logging.basicConfig(level=logging.ERROR) # do not log when running tests -from aggregatetunnelmetrics.fieldtestingcsv.model import Entry -from aggregatetunnelmetrics.fieldtestingcsv.parser import ( - parse_datetime, - parse_bool, - parse_file, - parse_single_row, - stream_file, -) - - -class TestParserFunctions(unittest.TestCase): - """Test individual parser functions.""" - - def test_parse_datetime(self): - """Test parse_datetime with valid and invalid inputs""" - # Valid UTC date - dt = parse_datetime("Mon Jan 01 12:00:00 UTC 2024") - self.assertEqual(dt, datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)) - - # Non-UTC timezone - with self.assertRaises(ValueError): - parse_datetime("Mon Jan 01 12:00:00 EST 2024") - - # Invalid format - with self.assertRaises(ValueError): - parse_datetime("2024-01-01 12:00:00") - - def test_parse_datetime_non_utc(self): - """Test parse_datetime explicitly with non-UTC timezone""" - with self.assertRaises(ValueError): - parse_datetime("Mon Jan 01 12:00:00 EST 2024") - - def test_parse_bool(self): - """Test parse_bool with various inputs""" - self.assertTrue(parse_bool("true")) - self.assertTrue(parse_bool("TRUE")) - self.assertFalse(parse_bool("false")) - self.assertFalse(parse_bool("FALSE")) - self.assertFalse(parse_bool("invalid")) - - -class TestRowParsing(unittest.TestCase): - """Test parsing individual rows.""" - - def test_parse_valid_row(self): - """Test parsing a valid row dictionary""" - row = { - "filename": "test.csv", - "date": "Mon Jan 01 12:00:00 UTC 2024", - "asn": "AS12345", - "isp": "TestISP", - "est_city": "TestCity", - "user": "user1", - "region": "TestRegion", - "server_fqdn": "ndt.example.com", - "server_ip": "1.2.3.4", - "mobile": "false", - "tunnel": "tunnel", - "throughput_download": "100.0", - "throughput_upload": "50.0", - "latency_download": "20.0", - "latency_upload": "25.0", - "retransmission_download": "0.01", - "retransmission_upload": "0.02", - "ping_packets_loss": "0.0", - "ping_roundtrip_min": "10.0", - "ping_roundtrip_avg": "12.0", - "ping_roundtrip_max": "15.0", - "err_message": "", - "PT": "obfs4", - } - - entry = parse_single_row(row) - self.assertEqual(entry.filename, "test.csv") - self.assertEqual(entry.protocol, "obfs4") - self.assertEqual(entry.throughput_download, 100.0) - self.assertEqual(entry.mobile, False) - self.assertEqual(entry.date, datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)) - self.assertEqual(entry.asn, "AS12345") - self.assertEqual(entry.isp, "TestISP") - self.assertEqual(entry.est_city, "TestCity") - self.assertEqual(entry.user, "user1") - self.assertEqual(entry.region, "TestRegion") - self.assertEqual(entry.server_fqdn, "ndt.example.com") - self.assertEqual(entry.server_ip, "1.2.3.4") - self.assertEqual(entry.tunnel, "tunnel") - self.assertEqual(entry.throughput_upload, 50.0) - self.assertEqual(entry.latency_download, 20.0) - self.assertEqual(entry.latency_upload, 25.0) - self.assertEqual(entry.retransmission_download, 0.01) - self.assertEqual(entry.retransmission_upload, 0.02) - self.assertEqual(entry.ping_packets_loss, 0.0) - self.assertEqual(entry.ping_roundtrip_min, 10.0) - self.assertEqual(entry.ping_roundtrip_avg, 12.0) - self.assertEqual(entry.ping_roundtrip_max, 15.0) - self.assertEqual(entry.err_message, "") - - def test_parse_row_missing_field(self): - """Test parsing row with missing required field""" - row = { - # Missing most required fields - "filename": "test.csv", - "date": "Mon Jan 01 12:00:00 UTC 2024", - } - - with self.assertRaises(KeyError): - parse_single_row(row) - - def test_parse_row_invalid_numeric(self): - """Test parsing row with invalid numeric values""" - valid_row = { - "filename": "test.csv", - "date": "Mon Jan 01 12:00:00 UTC 2024", - "asn": "AS12345", - "isp": "TestISP", - "est_city": "TestCity", - "user": "user1", - "region": "TestRegion", - "server_fqdn": "ndt.example.com", - "server_ip": "1.2.3.4", - "mobile": "false", - "tunnel": "tunnel", - "throughput_download": "invalid", # Invalid numeric value - "throughput_upload": "50.0", - "latency_download": "20.0", - "latency_upload": "25.0", - "retransmission_download": "0.01", - "retransmission_upload": "0.02", - "ping_packets_loss": "0.0", - "ping_roundtrip_min": "10.0", - "ping_roundtrip_avg": "12.0", - "ping_roundtrip_max": "15.0", - "err_message": "", - "PT": "obfs4", - } - - with self.assertRaises(ValueError): - parse_single_row(valid_row) - - -class TestFileOperations(unittest.TestCase): - """Test file-level operations (parse and stream).""" - - def setUp(self): - self.temp_dir = tempfile.mkdtemp() - self.csv_path = os.path.join(self.temp_dir, "test.csv") - - def tearDown(self): - try: - os.unlink(self.csv_path) - os.rmdir(self.temp_dir) - except FileNotFoundError: - pass - - def write_csv(self, content: str): - """Helper to write CSV content to temp file""" - with open(self.csv_path, "w") as f: - f.write(content) - - def test_stream_valid_file(self): - """Test streaming from a valid CSV file""" - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT + +def test_streamer_implements_interface(): + """Test that Streamer implements the Protocol interface""" + assert isinstance(fieldtestingcsv.Streamer(), fieldtesting.Streamer) + + +def test_parse_datetime(): + """Test parse_datetime with valid UTC date""" + dt = fieldtestingcsv.parser.parse_datetime("Mon Jan 01 12:00:00 UTC 2024") + assert dt == datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc) + + +def test_parse_datetime_non_utc(): + """Test parse_datetime explicitly with non-UTC timezone""" + with pytest.raises(ValueError): + fieldtestingcsv.parser.parse_datetime("Mon Jan 01 12:00:00 EST 2024") + + +def test_parse_datetime_invalid_format(): + """Test parse_datetime with invalid format""" + with pytest.raises(ValueError): + fieldtestingcsv.parser.parse_datetime("2024-01-01 12:00:00") + + +def test_parse_bool(): + """Test parse_bool with various inputs""" + assert fieldtestingcsv.parser.parse_bool("true") is True + assert fieldtestingcsv.parser.parse_bool("TRUE") is True + assert fieldtestingcsv.parser.parse_bool("false") is False + assert fieldtestingcsv.parser.parse_bool("FALSE") is False + assert fieldtestingcsv.parser.parse_bool("invalid") is False + + +def test_parse_valid_row(): + """Test parsing a valid row dictionary""" + row = { + "filename": "test.csv", + "date": "Mon Jan 01 12:00:00 UTC 2024", + "asn": "AS12345", + "isp": "TestISP", + "est_city": "TestCity", + "user": "user1", + "region": "TestRegion", + "server_fqdn": "ndt.example.com", + "server_ip": "1.2.3.4", + "mobile": "false", + "tunnel": "tunnel", + "throughput_download": "100.0", + "throughput_upload": "50.0", + "latency_download": "20.0", + "latency_upload": "25.0", + "retransmission_download": "0.01", + "retransmission_upload": "0.02", + "ping_packets_loss": "0.0", + "ping_roundtrip_min": "10.0", + "ping_roundtrip_avg": "12.0", + "ping_roundtrip_max": "15.0", + "err_message": "", + "PT": "obfs4", + } + + entry = fieldtestingcsv.parser.parse_single_row(row) + assert entry.filename == "test.csv" + assert entry.protocol == "obfs4" + assert entry.throughput_download == 100.0 + assert entry.mobile is False + assert entry.date == datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc) + assert entry.asn == "AS12345" + assert entry.isp == "TestISP" + assert entry.est_city == "TestCity" + assert entry.user == "user1" + assert entry.region == "TestRegion" + assert entry.server_fqdn == "ndt.example.com" + assert entry.server_ip == "1.2.3.4" + assert entry.tunnel == "tunnel" + assert entry.throughput_upload == 50.0 + assert entry.latency_download == 20.0 + assert entry.latency_upload == 25.0 + assert entry.retransmission_download == 0.01 + assert entry.retransmission_upload == 0.02 + assert entry.ping_packets_loss == 0.0 + assert entry.ping_roundtrip_min == 10.0 + assert entry.ping_roundtrip_avg == 12.0 + assert entry.ping_roundtrip_max == 15.0 + assert entry.err_message == "" + + +def test_parse_row_missing_field(): + """Test parsing row with missing required field""" + row = { + # Missing most required fields + "filename": "test.csv", + "date": "Mon Jan 01 12:00:00 UTC 2024", + } + + with pytest.raises(KeyError): + fieldtestingcsv.parser.parse_single_row(row) + + +def test_parse_row_invalid_numeric(): + """Test parsing row with invalid numeric values""" + valid_row = { + "filename": "test.csv", + "date": "Mon Jan 01 12:00:00 UTC 2024", + "asn": "AS12345", + "isp": "TestISP", + "est_city": "TestCity", + "user": "user1", + "region": "TestRegion", + "server_fqdn": "ndt.example.com", + "server_ip": "1.2.3.4", + "mobile": "false", + "tunnel": "tunnel", + "throughput_download": "invalid", # Invalid numeric value + "throughput_upload": "50.0", + "latency_download": "20.0", + "latency_upload": "25.0", + "retransmission_download": "0.01", + "retransmission_upload": "0.02", + "ping_packets_loss": "0.0", + "ping_roundtrip_min": "10.0", + "ping_roundtrip_avg": "12.0", + "ping_roundtrip_max": "15.0", + "err_message": "", + "PT": "obfs4", + } + + with pytest.raises(ValueError): + fieldtestingcsv.parser.parse_single_row(valid_row) + + +@pytest.fixture +def temp_csv_file(): + """Fixture providing a temporary CSV file""" + temp_dir = tempfile.mkdtemp() + csv_path = os.path.join(temp_dir, "test.csv") + yield csv_path + try: + os.unlink(csv_path) + os.rmdir(temp_dir) + except FileNotFoundError: + pass + + +def write_csv(path: str, content: str): + """Helper to write CSV content to file""" + with open(path, "w") as f: + f.write(content) + + +def test_stream_valid_file(temp_csv_file): + """Test streaming from a valid CSV file""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" - self.write_csv(csv_content) - entries = list(stream_file(self.csv_path)) - self.assertEqual(len(entries), 1) - self.assertEqual(entries[0].protocol, "obfs4") + write_csv(temp_csv_file, csv_content) + entries = list(fieldtestingcsv.parser.stream_file(temp_csv_file)) + assert len(entries) == 1 + assert entries[0].protocol == "obfs4" - def test_stream_file_with_invalid_rows(self): - """Test streaming from a file with invalid rows""" - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT + +def test_stream_file_with_invalid_rows(temp_csv_file): + """Test streaming from a file with invalid rows""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT test.csv,Mon Jan 01 12:00:00 EST 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4 test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,not_a_number,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" - self.write_csv(csv_content) - entries = list(stream_file(self.csv_path)) - # Both rows should be skipped due to errors - self.assertEqual(len(entries), 0) + write_csv(temp_csv_file, csv_content) + entries = list(fieldtestingcsv.parser.stream_file(temp_csv_file)) + assert len(entries) == 0 + - def test_parse_equivalent_to_stream(self): - """Test that parse_file() returns same results as stream_file()""" - csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +def test_parse_equivalent_to_stream(temp_csv_file): + """Test that parse_file() returns same results as stream_file()""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" - self.write_csv(csv_content) - streamed = list(stream_file(self.csv_path)) - parsed = parse_file(self.csv_path) - self.assertEqual(streamed, parsed) + write_csv(temp_csv_file, csv_content) + streamed = list(fieldtestingcsv.parser.stream_file(temp_csv_file)) + parsed = fieldtestingcsv.parse_file(temp_csv_file) + assert streamed == parsed -class TestEntryMethods(unittest.TestCase): - """Test Entry class methods.""" +@pytest.fixture +def make_entry(): + """Fixture providing an Entry factory function""" - @staticmethod - def make_entry(tunnel: str) -> Entry: - return Entry( + def _make_entry(tunnel: str) -> fieldtestingcsv.Entry: + return fieldtestingcsv.Entry( filename="test.csv", date=datetime.now(timezone.utc), asn="AS12345", @@ -230,19 +226,62 @@ class TestEntryMethods(unittest.TestCase): ping_roundtrip_max=15.0, err_message="", protocol="obfs4", + ping_target_address=None, + ndt_target_hostname=None, + ndt_target_address=None, + ndt_target_port=None, ) - def test_is_tunnel_measurement(self): - """Test is_tunnel_measurement() method""" - self.assertTrue(self.make_entry("tunnel").is_tunnel_measurement()) - self.assertFalse(self.make_entry("baseline").is_tunnel_measurement()) - self.assertTrue(self.make_entry("ERROR/tunnel").is_tunnel_measurement()) + return _make_entry + + +def test_is_tunnel_measurement(make_entry): + """Test is_tunnel_measurement() method""" + assert make_entry("tunnel").is_tunnel_measurement() is True + assert make_entry("baseline").is_tunnel_measurement() is False + assert make_entry("ERROR/tunnel").is_tunnel_measurement() is True + + +def test_is_tunnel_error_measurement(make_entry): + """Test is_tunnel_error_measurement() method""" + assert make_entry("ERROR/tunnel").is_tunnel_error_measurement() is True + assert make_entry("tunnel").is_tunnel_error_measurement() is False + + +def test_streamer_stream_method(temp_csv_file): + """Test that Streamer.stream() works correctly""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" + + write_csv(temp_csv_file, csv_content) + + # Test using the Streamer class + streamer = fieldtestingcsv.Streamer() + entries = list(streamer.stream(temp_csv_file)) + + assert len(entries) == 1 + assert entries[0].protocol == "obfs4" + assert entries[0].filename == "test.csv" + assert entries[0].tunnel == "tunnel" + + +def test_streamer_stream_invalid_file(): + """Test that Streamer.stream() handles invalid files appropriately""" + streamer = fieldtestingcsv.Streamer() + + with pytest.raises(FileNotFoundError): + list(streamer.stream("nonexistent_file.csv")) + + +def test_streamer_stream_matches_direct_parse(temp_csv_file): + """Test that Streamer.stream() returns same results as parse_file()""" + csv_content = """filename,date,asn,isp,est_city,user,region,server_fqdn,server_ip,mobile,tunnel,throughput_download,throughput_upload,latency_download,latency_upload,retransmission_download,retransmission_upload,ping_packets_loss,ping_roundtrip_min,ping_roundtrip_avg,ping_roundtrip_max,err_message,PT +test.csv,Mon Jan 01 12:00:00 UTC 2024,AS12345,TestISP,TestCity,user1,TestRegion,ndt.example.com,1.2.3.4,false,tunnel,100.0,50.0,20.0,25.0,0.01,0.02,0.0,10.0,12.0,15.0,,obfs4""" - def test_is_tunnel_error_measurement(self): - """Test is_tunnel_error_measurement() method""" - self.assertTrue(self.make_entry("ERROR/tunnel").is_tunnel_error_measurement()) - self.assertFalse(self.make_entry("tunnel").is_tunnel_error_measurement()) + write_csv(temp_csv_file, csv_content) + streamer = fieldtestingcsv.Streamer() + streamed_entries = list(streamer.stream(temp_csv_file)) + parsed_entries = fieldtestingcsv.parse_file(temp_csv_file) -if __name__ == "__main__": - unittest.main() + assert streamed_entries == parsed_entries -- GitLab From d0a078e29bf23bdfcb1c84824ba9973318b6a592 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 17:37:12 +0100 Subject: [PATCH 56/75] refactor(oonireport): use protocols and pytest --- .../fieldtestingcsv/__init__.py | 9 +- aggregatetunnelmetrics/oonireport/__init__.py | 2 + aggregatetunnelmetrics/oonireport/__main__.py | 4 +- .../oonireport/collector.py | 97 +--- aggregatetunnelmetrics/oonireport/model.py | 100 +---- aggregatetunnelmetrics/spec/oonicollector.py | 6 +- .../oonireport/__init__.py | 2 + .../oonireport/test_collector.py | 404 +++++++---------- .../oonireport/test_load.py | 238 +++++----- .../oonireport/test_main.py | 423 +++++++----------- .../oonireport/test_model.py | 144 ------ .../spec/test_oonicollector.py | 43 +- 12 files changed, 540 insertions(+), 932 deletions(-) delete mode 100644 tests/aggregatetunnelmetrics/oonireport/test_model.py diff --git a/aggregatetunnelmetrics/fieldtestingcsv/__init__.py b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py index c822140..8a22f7e 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/__init__.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/__init__.py @@ -1,10 +1,9 @@ """ -Field-Testing CSV -================= +Field-Testing CSV Streamer +========================== -This package contains code for parsing the field-testing CSV data format. - -See https://0xacab.org/leap/solitech-compose-client/-/blob/main/images/obfsvpn-openvpn-client/start.sh. +This package contains Streamer that implements the +spec.fieldtesting.Streamer protocol. """ # SPDX-License-Identifier: GPL-3.0-or-later diff --git a/aggregatetunnelmetrics/oonireport/__init__.py b/aggregatetunnelmetrics/oonireport/__init__.py index e1f1b26..6c2ff35 100644 --- a/aggregatetunnelmetrics/oonireport/__init__.py +++ b/aggregatetunnelmetrics/oonireport/__init__.py @@ -5,6 +5,8 @@ OONI Report Submission Library This package provides classes for interacting with OONI collectors servers and submitting OONI measurements to them. +The CollectorClient class implements spec.oonicollector.Client. + See https://github.com/ooni/spec/blob/master/backends/bk-003-collector.md. We implement the OONI collector protocol v3.0.0. diff --git a/aggregatetunnelmetrics/oonireport/__main__.py b/aggregatetunnelmetrics/oonireport/__main__.py index 87e8a6f..0f656ca 100644 --- a/aggregatetunnelmetrics/oonireport/__main__.py +++ b/aggregatetunnelmetrics/oonireport/__main__.py @@ -88,8 +88,8 @@ def main(args: Optional[List[str]] = None) -> int: report_id = client.create_report_from_measurement(measurement) print(f"oonireport: created report {report_id}", file=sys.stderr) - # Append measurement to the report - measurement_uid = client.update_report(report_id, measurement) + # Submit measurement to the report + measurement_uid = client.submit_measurement(report_id, measurement) if not measurement_uid: measurement_uid = "N/A" print( diff --git a/aggregatetunnelmetrics/oonireport/collector.py b/aggregatetunnelmetrics/oonireport/collector.py index db519c6..e91ed0f 100644 --- a/aggregatetunnelmetrics/oonireport/collector.py +++ b/aggregatetunnelmetrics/oonireport/collector.py @@ -9,24 +9,21 @@ Please, prefer importing the `oonireport` package directly. # SPDX-License-Identifier: GPL-3.0-or-later from dataclasses import dataclass -from datetime import datetime -from typing import Optional from urllib.parse import urljoin import json import urllib.request -from .model import APIError, Measurement, datetime_to_ooni_format +from ..spec import oonicollector +from .model import APIError, Measurement -@dataclass -class CollectorConfig: - """Contains configuration for the OONI collector client.""" - collector_base_url: str # e.g., "https://api.ooni.io/" - timeout: float = 30.0 +CollectorConfig = oonicollector.Config +"""Type alias for the `oonicollector.Config` type.""" +@dataclass(frozen=True) class CollectorClient: """ Implements the OONI collector client protocol. @@ -45,54 +42,20 @@ class CollectorClient: # # However, for now, we're good without implementing retries. - def __init__(self, config: CollectorConfig): - self.config = config + config: CollectorConfig - def create_report_from_measurement(self, measurement: Measurement) -> str: + def create_report_from_measurement( + self, measurement: Measurement + ) -> oonicollector.ReportID: """Convenience method to create report from existing OONI Measurement.""" - return self.create_report( - test_name=measurement.test_name, - test_version=measurement.test_version, - software_name=measurement.software_name, - software_version=measurement.software_version, - probe_asn=measurement.probe_asn, - probe_cc=measurement.probe_cc, - test_start_time=measurement.test_start_time, - ) + return self.create_report(measurement.as_open_report_request()) def create_report( - self, - test_name: str, - test_version: str, - software_name: str, - software_version: str, - probe_asn: str, - probe_cc: str, - test_start_time: datetime, - ) -> str: - """ - Creates a new report and returns the report ID. - - Returns: - Report ID to use for submitting measurements. - - Raises: - model.APIError: in case of failure. - """ - report = { - "data_format_version": "0.2.0", - "format": "json", - "probe_asn": probe_asn, - "probe_cc": probe_cc, - "software_name": software_name, - "software_version": software_version, - "test_name": test_name, - "test_start_time": datetime_to_ooni_format(test_start_time), - "test_version": test_version, - } - - data = json.dumps(report).encode("utf-8") - req = urllib.request.Request( + self, req: oonicollector.OpenReportRequest + ) -> oonicollector.ReportID: + """Creates a new report and returns the report ID.""" + data = json.dumps(req.as_dict()).encode("utf-8") + httpreq = urllib.request.Request( urljoin(self.config.collector_base_url, "report"), data=data, headers={"Content-Type": "application/json"}, @@ -100,7 +63,7 @@ class CollectorClient: ) try: - with urllib.request.urlopen(req, timeout=self.config.timeout) as resp: + with urllib.request.urlopen(httpreq, timeout=self.config.timeout) as resp: if resp.status != 200: raise APIError(f"unexpected status: {resp.status}") response = json.loads(resp.read().decode()) @@ -111,27 +74,13 @@ class CollectorClient: except Exception as exc: raise APIError(f"HTTP error: {exc}") - def update_report( + def submit_measurement( self, - report_id: str, - measurement: Measurement, - ) -> Optional[str]: - """ - Update a report by adding a measurement. - - Args: - report_id: The ID returned by create_report(). - measurement: The measurement to submit. - - Returns: - The measurement_uid, if provided by server, otherwise None. - - Raises: - model.APIError: in case of failure. - """ - measurement = measurement.with_report_id( - report_id - ) # Required for Explorer visualization + rid: oonicollector.ReportID, + m: oonicollector.Measurement, + ) -> oonicollector.MaybeMeasurementID: + """Submit measurement and returns the measurement ID.""" + measurement = m.with_report_id(rid) # Required for Explorer visualization data = json.dumps( { "format": "json", @@ -140,7 +89,7 @@ class CollectorClient: ).encode("utf-8") req = urllib.request.Request( - urljoin(self.config.collector_base_url, f"report/{report_id}"), + urljoin(self.config.collector_base_url, f"report/{rid}"), data=data, headers={"Content-Type": "application/json"}, method="POST", diff --git a/aggregatetunnelmetrics/oonireport/model.py b/aggregatetunnelmetrics/oonireport/model.py index 3fae6d2..6c620c6 100644 --- a/aggregatetunnelmetrics/oonireport/model.py +++ b/aggregatetunnelmetrics/oonireport/model.py @@ -8,105 +8,19 @@ Please, prefer importing the `oonireport` package directly. # SPDX-License-Identifier: GPL-3.0-or-later -from dataclasses import dataclass, field, replace -from datetime import datetime, timezone -from typing import Any, Dict, Protocol +from ..spec import oonicollector class APIError(Exception): """Raised when there are OONI API errors.""" -class TestKeys(Protocol): - """ - Models the OONI measurement test keys. +TestKeys = oonicollector.MeasurementTestKeys +"""Alias for the `oonicollector.MeasurementTestKeys` type.""" - Methods: - as_dict: Converts the test keys to a JSON-serializable dict. - """ - def as_dict(self) -> Dict[str, Any]: ... +Measurement = oonicollector.Measurement +"""Alias for the `oonicollector.Measurement` type.""" - -@dataclass(frozen=True) -class Measurement: - """Models the OONI measurement envelope.""" - - # mandatory fields - annotations: Dict[str, str] - data_format_version: str - input: str # e.g., {protocol}://{provider}/?{query_string} - measurement_start_time: datetime - probe_asn: str # Format: ^AS[0-9]+$ - probe_cc: str # Format: ^[A-Z]{2}$ - software_name: str - software_version: str - test_keys: TestKeys - test_name: str - test_runtime: float - test_start_time: datetime - test_version: str - - # Fields emitted with possibly default values - probe_ip: str = "127.0.0.1" - report_id: str = "" - - # Optional fields - options: list[str] = field(default_factory=list) - probe_network_name: str = "" - resolver_asn: str = "" - resolver_cc: str = "" - resolver_ip: str = "" - resolver_network_name: str = "" - test_helpers: Dict[str, Any] = field(default_factory=dict) - - def as_dict(self) -> Dict: - """Converts the measurement to a JSON-serializable dict""" - - # Add mandatory fields - dct = { - "annotations": self.annotations, - "data_format_version": self.data_format_version, - "input": self.input, - "measurement_start_time": datetime_to_ooni_format( - self.measurement_start_time - ), - "probe_asn": self.probe_asn, - "probe_cc": self.probe_cc, - "software_name": self.software_name, - "software_version": self.software_version, - "test_keys": self.test_keys.as_dict(), - "test_name": self.test_name, - "test_runtime": self.test_runtime, - "test_start_time": datetime_to_ooni_format(self.test_start_time), - "test_version": self.test_version, - } - - # Fields emitted with possibly default values - dct["probe_ip"] = self.probe_ip if self.probe_ip else "127.0.0.1" - dct["report_id"] = self.report_id - - # Add optional fields - if self.options: - dct["options"] = self.options - if self.probe_network_name: - dct["probe_network_name"] = self.probe_network_name - if self.resolver_asn: - dct["resolver_asn"] = self.resolver_asn - if self.resolver_ip: - dct["resolver_ip"] = self.resolver_ip - if self.resolver_network_name: - dct["resolver_network_name"] = self.resolver_network_name - if self.test_helpers: - dct["test_helpers"] = self.test_helpers - - return dct - - def with_report_id(self, report_id: str) -> "Measurement": - """Creates a new Measurement instance with the given report_id.""" - return replace(self, report_id=report_id) - - -def datetime_to_ooni_format(dt: datetime) -> str: - """Converts a datetime to OONI's datetime format (YYYY-mm-dd HH:MM:SS).""" - return dt.astimezone(timezone.utc).strftime("%Y-%m-%d %H:%M:%S") +datetime_to_ooni_format = oonicollector.format_datetime +"""Alias for the `oonicollector.format_datetime` function.""" diff --git a/aggregatetunnelmetrics/spec/oonicollector.py b/aggregatetunnelmetrics/spec/oonicollector.py index 5784d71..787c0f5 100644 --- a/aggregatetunnelmetrics/spec/oonicollector.py +++ b/aggregatetunnelmetrics/spec/oonicollector.py @@ -5,7 +5,7 @@ OONI Collector Model This module contains the OONI collector data model. Classes: - TestKeys: models the OONI measurement experiment-specific test keys. + MeasurementTestKeys: models the OONI measurement experiment-specific test keys. Measurement: models the OONI measurement envelope. Config: configures the OONI collector client. OpenReportRequest: contains data required to open an OONI report. @@ -24,7 +24,7 @@ from typing import Any, Protocol, runtime_checkable @runtime_checkable -class TestKeys(Protocol): +class MeasurementTestKeys(Protocol): """ Models the OONI measurement test keys. @@ -55,7 +55,7 @@ class Measurement: probe_cc: str # Format: ^[A-Z]{2}$ software_name: str software_version: str - test_keys: TestKeys + test_keys: MeasurementTestKeys test_name: str test_runtime: float test_start_time: datetime diff --git a/tests/aggregatetunnelmetrics/oonireport/__init__.py b/tests/aggregatetunnelmetrics/oonireport/__init__.py index 8b13789..51a6999 100644 --- a/tests/aggregatetunnelmetrics/oonireport/__init__.py +++ b/tests/aggregatetunnelmetrics/oonireport/__init__.py @@ -1 +1,3 @@ +"""Tests for the oonireport package.""" +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/oonireport/test_collector.py b/tests/aggregatetunnelmetrics/oonireport/test_collector.py index df2ec90..600e44b 100644 --- a/tests/aggregatetunnelmetrics/oonireport/test_collector.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_collector.py @@ -2,11 +2,11 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from unittest.mock import patch -from datetime import datetime, timezone +from datetime import datetime, timezone import json -import unittest +import pytest +from unittest.mock import patch from aggregatetunnelmetrics.oonireport import ( APIError, @@ -14,6 +14,14 @@ from aggregatetunnelmetrics.oonireport import ( CollectorConfig, Measurement, ) +from aggregatetunnelmetrics.spec.oonicollector import Client, OpenReportRequest + + +class SimpleTestKeys: + """Simple TestKeys implementation for testing.""" + + def as_dict(self): + return {"simple": "test"} class MockResponse: @@ -33,238 +41,164 @@ class MockResponse: pass -class SimpleTestKeys: - """Simple TestKeys implementation for testing.""" +@pytest.fixture +def config(): + """Returns a basic collector configuration.""" + return CollectorConfig(collector_base_url="https://example.org", timeout=30.0) - def as_dict(self): - return {"simple": "test"} +@pytest.fixture +def client(config): + """Returns a configured collector client.""" + return CollectorClient(config) + + +@pytest.fixture +def test_datetime(): + """Returns a fixed datetime for testing.""" + return datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) + + +@pytest.fixture +def test_report_request(test_datetime): + """Returns a test OpenReportRequest.""" + return OpenReportRequest( + test_name="web_connectivity", + test_version="0.0.1", + software_name="ooniprobe", + software_version="3.0.0", + probe_asn="AS12345", + probe_cc="IT", + test_start_time=test_datetime, + ) + + +def test_collector_client_implements_protocol(client): + """Test that CollectorClient implements the oonicollector.Client protocol.""" + assert isinstance(client, Client) + + +@patch("urllib.request.urlopen") +def test_create_report_success(mock_urlopen, client, test_report_request): + """Test successful report creation.""" + mock_urlopen.return_value = MockResponse(200, {"report_id": "test_report_id"}) + + report_id = client.create_report(test_report_request) + + assert report_id == "test_report_id" + + +@patch("urllib.request.urlopen") +def test_create_report_http_error(mock_urlopen, client, test_report_request): + """Test report creation with HTTP error.""" + mock_urlopen.side_effect = Exception("Connection failed") + + with pytest.raises(APIError, match="HTTP error"): + client.create_report(test_report_request) + + +@patch("urllib.request.urlopen") +def test_create_report_invalid_response(mock_urlopen, client, test_report_request): + """Test report creation with invalid response.""" + mock_urlopen.return_value = MockResponse( + 200, {"wrong_field": "value"} # Missing report_id + ) + + with pytest.raises(APIError, match="missing report_id"): + client.create_report(test_report_request) + + +@pytest.fixture +def test_measurement(test_datetime): + """Returns a test measurement.""" + return Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=test_datetime, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=test_datetime, + test_version="0.0.1", + ) + + +@patch("urllib.request.urlopen") +def test_submit_measurement_success(mock_urlopen, client, test_measurement): + """Test successful measurement submission.""" + mock_urlopen.return_value = MockResponse( + 200, {"measurement_uid": "test_measurement_uid"} + ) + + measurement_id = client.submit_measurement("test_report_id", test_measurement) + + assert measurement_id == "test_measurement_uid" + + +@patch("urllib.request.urlopen") +def test_submit_measurement_http_error(mock_urlopen, client, test_measurement): + """Test measurement submission with HTTP error.""" + mock_urlopen.side_effect = Exception("Connection failed") + + with pytest.raises(APIError, match="HTTP error"): + client.submit_measurement("test_report_id", test_measurement) + + +@patch("urllib.request.urlopen") +def test_create_report_api_error(mock_urlopen, client, test_report_request): + """Test report creation with non-200 status code.""" + mock_urlopen.return_value = MockResponse(500, {"error": "Internal Server Error"}) + + with pytest.raises(APIError, match="unexpected status: 500"): + client.create_report(test_report_request) + + +@patch("urllib.request.urlopen") +def test_submit_measurement_api_error(mock_urlopen, client, test_measurement): + """Test measurement submission with non-200 status code.""" + mock_urlopen.return_value = MockResponse(403, {"error": "Forbidden"}) + + with pytest.raises(APIError, match="unexpected status: 403"): + client.submit_measurement("test_report_id", test_measurement) + + +@patch("urllib.request.urlopen") +def test_submit_measurement_invalid_json(mock_urlopen, client, test_measurement): + """Test measurement submission with invalid JSON response.""" + + class InvalidJSONResponse: + status = 200 + + def read(self): + return b"invalid json" + + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + mock_urlopen.return_value = InvalidJSONResponse() + + # Should return None when JSON decoding fails but status is 200 + result = client.submit_measurement("test_report_id", test_measurement) + assert result is None + + +@patch("urllib.request.urlopen") +def test_create_report_from_measurement(mock_urlopen, client, test_measurement): + """Test creating a report from an existing measurement.""" + mock_urlopen.return_value = MockResponse(200, {"report_id": "test_report_id"}) + + report_id = client.create_report_from_measurement(test_measurement) -class TestCollectorClient(unittest.TestCase): - def setUp(self): - self.config = CollectorConfig( - collector_base_url="https://example.org", timeout=30.0 - ) - self.client = CollectorClient(self.config) - self.dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) - - @patch("urllib.request.urlopen") - def test_create_report_success(self, mock_urlopen): - mock_urlopen.return_value = MockResponse(200, {"report_id": "test_report_id"}) - - report_id = self.client.create_report( - test_name="web_connectivity", - test_version="0.0.1", - software_name="ooniprobe", - software_version="3.0.0", - probe_asn="AS12345", - probe_cc="IT", - test_start_time=self.dt, - ) - - self.assertEqual(report_id, "test_report_id") - - @patch("urllib.request.urlopen") - def test_create_report_http_error(self, mock_urlopen): - mock_urlopen.side_effect = Exception("Connection failed") - - with self.assertRaises(APIError) as cm: - self.client.create_report( - test_name="web_connectivity", - test_version="0.0.1", - software_name="ooniprobe", - software_version="3.0.0", - probe_asn="AS12345", - probe_cc="IT", - test_start_time=self.dt, - ) - - self.assertIn("HTTP error", str(cm.exception)) - - @patch("urllib.request.urlopen") - def test_create_report_invalid_response(self, mock_urlopen): - mock_urlopen.return_value = MockResponse( - 200, {"wrong_field": "value"} # Missing report_id - ) - - with self.assertRaises(APIError) as cm: - self.client.create_report( - test_name="web_connectivity", - test_version="0.0.1", - software_name="ooniprobe", - software_version="3.0.0", - probe_asn="AS12345", - probe_cc="IT", - test_start_time=self.dt, - ) - - self.assertIn("missing report_id", str(cm.exception)) - - @patch("urllib.request.urlopen") - def test_update_report_success(self, mock_urlopen): - mock_urlopen.return_value = MockResponse( - 200, {"measurement_uid": "test_measurement_uid"} - ) - - measurement = Measurement( - annotations={}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=self.dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=datetime.now(timezone.utc), - test_version="0.0.1", - ) - - measurement_id = self.client.update_report("test_report_id", measurement) - - self.assertEqual(measurement_id, "test_measurement_uid") - - @patch("urllib.request.urlopen") - def test_update_report_http_error(self, mock_urlopen): - mock_urlopen.side_effect = Exception("Connection failed") - - measurement = Measurement( - annotations={}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=self.dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=datetime.now(timezone.utc), - test_version="0.0.1", - ) - - with self.assertRaises(APIError) as cm: - self.client.update_report("test_report_id", measurement) - - self.assertIn("HTTP error", str(cm.exception)) - - @patch("urllib.request.urlopen") - def test_create_report_from_measurement(self, mock_urlopen): - """Test create_report_from_measurement properly calls create_report.""" - mock_urlopen.return_value = MockResponse(200, {"report_id": "test_report_id"}) - - # Create a measurement instance - measurement = Measurement( - annotations={}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=self.dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=self.dt, - test_version="0.0.1", - ) - - report_id = self.client.create_report_from_measurement(measurement) - - self.assertEqual(report_id, "test_report_id") - # Verify the correct URL was called - self.assertEqual(mock_urlopen.call_count, 1) - request = mock_urlopen.call_args[0][0] - self.assertEqual(request.full_url, "https://example.org/report") - - @patch("urllib.request.urlopen") - def test_create_report_api_error(self, mock_urlopen): - """Test create_report handles non-200 status code.""" - mock_urlopen.return_value = MockResponse( - 500, {"error": "Internal Server Error"} - ) - - with self.assertRaises(APIError) as cm: - self.client.create_report( - test_name="web_connectivity", - test_version="0.0.1", - software_name="ooniprobe", - software_version="3.0.0", - probe_asn="AS12345", - probe_cc="IT", - test_start_time=self.dt, - ) - - self.assertIn("unexpected status: 500", str(cm.exception)) - - @patch("urllib.request.urlopen") - def test_update_report_api_error(self, mock_urlopen): - """Test update_report handles non-200 status code.""" - mock_urlopen.return_value = MockResponse(403, {"error": "Forbidden"}) - - measurement = Measurement( - annotations={}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=self.dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=self.dt, - test_version="0.0.1", - ) - - with self.assertRaises(APIError) as cm: - self.client.update_report("test_report_id", measurement) - - self.assertIn("unexpected status: 403", str(cm.exception)) - - @patch("urllib.request.urlopen") - def test_update_report_invalid_json(self, mock_urlopen): - """Test update_report handles invalid JSON response.""" - - class InvalidJSONResponse: - status = 200 - - def read(self): - return b"invalid json" - - def __enter__(self): - return self - - def __exit__(self, *args): - pass - - mock_urlopen.return_value = InvalidJSONResponse() - - measurement = Measurement( - annotations={}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=self.dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=self.dt, - test_version="0.0.1", - ) - - # Should return None when JSON decoding fails but status is 200 - result = self.client.update_report("test_report_id", measurement) - self.assertIsNone(result) - - -if __name__ == "__main__": - unittest.main() + assert report_id == "test_report_id" + # Verify that the correct request was made + assert mock_urlopen.call_count == 1 + request = mock_urlopen.call_args[0][0] + assert request.full_url == "https://example.org/report" diff --git a/tests/aggregatetunnelmetrics/oonireport/test_load.py b/tests/aggregatetunnelmetrics/oonireport/test_load.py index e4e2211..a70a72b 100644 --- a/tests/aggregatetunnelmetrics/oonireport/test_load.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_load.py @@ -1,11 +1,9 @@ -"""Tests for measurement loading functionality.""" +"""Tests for the oonireport.load module.""" # SPDX-License-Identifier: GPL-3.0-or-later import json -import os -import tempfile -import unittest +import pytest from aggregatetunnelmetrics.oonireport.load import ( load_measurements, @@ -30,125 +28,113 @@ VALID_MEASUREMENT = { } -class TestMeasurementValidation(unittest.TestCase): - """Tests for validating individual measurements.""" - - def test_missing_required_fields(self): - """Test that missing required fields raise ValueError.""" - # Remove a required field - invalid_measurement = VALID_MEASUREMENT.copy() - del invalid_measurement["probe_asn"] - - with self.assertRaises(ValueError) as cm: - load_single_measurement(json.dumps(invalid_measurement)) - - self.assertIn("Missing required fields", str(cm.exception)) - self.assertIn("probe_asn", str(cm.exception)) - - def test_invalid_date_format_measurement_start_time(self): - """Test that invalid measurement_start_time format raises ValueError.""" - invalid_measurement = VALID_MEASUREMENT.copy() - invalid_measurement["measurement_start_time"] = ( - "2023-13-32 25:61:61" # Invalid date - ) - - with self.assertRaises(ValueError) as cm: - load_single_measurement(json.dumps(invalid_measurement)) - - self.assertIn("Invalid datetime format", str(cm.exception)) - - def test_invalid_date_format_test_start_time(self): - """Test that invalid test_start_time format raises ValueError.""" - invalid_measurement = VALID_MEASUREMENT.copy() - invalid_measurement["test_start_time"] = "not-a-date" # Invalid date - - with self.assertRaises(ValueError) as cm: - load_single_measurement(json.dumps(invalid_measurement)) - - self.assertIn("Invalid datetime format", str(cm.exception)) - - -class TestFileLoading(unittest.TestCase): - """Tests for loading measurements from files.""" - - def setUp(self): - """Create a temporary file for testing.""" - self.temp_file = tempfile.NamedTemporaryFile(mode="w", delete=False) - - def tearDown(self): - """Clean up temporary file.""" - os.unlink(self.temp_file.name) - - def test_load_valid_measurements(self): - """Test loading multiple valid measurements.""" - # Write two valid measurements to file - with open(self.temp_file.name, "w") as f: - f.write(json.dumps(VALID_MEASUREMENT) + "\n") - f.write(json.dumps(VALID_MEASUREMENT) + "\n") - - measurements = load_measurements(self.temp_file.name) - - self.assertEqual(len(measurements), 2) - for m in measurements: - # Test all fields - self.assertEqual(m.annotations, {}) - self.assertEqual(m.data_format_version, "0.2.0") - self.assertEqual(m.input, "https://example.com/") - self.assertEqual(m.probe_asn, "AS12345") - self.assertEqual(m.probe_cc, "IT") - self.assertEqual(m.probe_ip, "127.0.0.1") - self.assertEqual(m.report_id, "") - self.assertEqual(m.software_name, "ooniprobe") - self.assertEqual(m.software_version, "3.0.0") - self.assertEqual(m.test_keys.as_dict(), {"simple": "test"}) - self.assertEqual(m.test_name, "web_connectivity") - self.assertEqual(m.test_runtime, 1.0) - self.assertEqual(m.test_version, "0.0.1") - - # Verify we're only loading expected fields - expected_fields = { - "annotations", - "data_format_version", - "input", - "measurement_start_time", - "probe_asn", - "probe_cc", - "probe_ip", - "report_id", - "software_name", - "software_version", - "test_keys", - "test_name", - "test_runtime", - "test_start_time", - "test_version", - } - self.assertEqual(set(m.as_dict().keys()), expected_fields) - - def test_load_empty_file(self): - """Test loading from an empty file.""" - with open(self.temp_file.name, "w") as f: - f.write("") - - measurements = load_measurements(self.temp_file.name) - self.assertEqual(len(measurements), 0) - - def test_load_file_with_blank_lines(self): - """Test loading file with blank lines between measurements.""" - with open(self.temp_file.name, "w") as f: - f.write(json.dumps(VALID_MEASUREMENT) + "\n") - f.write("\n") # blank line - f.write(json.dumps(VALID_MEASUREMENT) + "\n") - f.write("\n") # blank line - - measurements = load_measurements(self.temp_file.name) - self.assertEqual(len(measurements), 2) - - def test_file_not_found(self): - """Test attempting to load from non-existent file.""" - with self.assertRaises(OSError): - load_measurements("nonexistent_file.json") - - -if __name__ == "__main__": - unittest.main() +@pytest.fixture +def temp_measurement_file(tmp_path): + """Create a temporary file for testing.""" + measurement_file = tmp_path / "test_measurements.json" + yield measurement_file + # Cleanup happens automatically with tmp_path + + +def test_missing_required_fields(): + """Test that missing required fields raise ValueError.""" + # Remove a required field + invalid_measurement = VALID_MEASUREMENT.copy() + del invalid_measurement["probe_asn"] + + with pytest.raises(ValueError, match="Missing required fields.*probe_asn"): + load_single_measurement(json.dumps(invalid_measurement)) + + +def test_invalid_date_format_measurement_start_time(): + """Test that invalid measurement_start_time format raises ValueError.""" + invalid_measurement = VALID_MEASUREMENT.copy() + invalid_measurement["measurement_start_time"] = ( + "2023-13-32 25:61:61" # Invalid date + ) + + with pytest.raises(ValueError, match="Invalid datetime format"): + load_single_measurement(json.dumps(invalid_measurement)) + + +def test_invalid_date_format_test_start_time(): + """Test that invalid test_start_time format raises ValueError.""" + invalid_measurement = VALID_MEASUREMENT.copy() + invalid_measurement["test_start_time"] = "not-a-date" # Invalid date + + with pytest.raises(ValueError, match="Invalid datetime format"): + load_single_measurement(json.dumps(invalid_measurement)) + + +def test_load_valid_measurements(temp_measurement_file): + """Test loading multiple valid measurements.""" + # Write two valid measurements to file + with open(temp_measurement_file, "w") as f: + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + + measurements = load_measurements(temp_measurement_file) + + assert len(measurements) == 2 + for m in measurements: + # Test all fields + assert m.annotations == {} + assert m.data_format_version == "0.2.0" + assert m.input == "https://example.com/" + assert m.probe_asn == "AS12345" + assert m.probe_cc == "IT" + assert m.probe_ip == "127.0.0.1" + assert m.report_id == "" + assert m.software_name == "ooniprobe" + assert m.software_version == "3.0.0" + assert m.test_keys.as_dict() == {"simple": "test"} + assert m.test_name == "web_connectivity" + assert m.test_runtime == 1.0 + assert m.test_version == "0.0.1" + + # Verify we're only loading expected fields + expected_fields = { + "annotations", + "data_format_version", + "input", + "measurement_start_time", + "probe_asn", + "probe_cc", + "probe_ip", + "report_id", + "software_name", + "software_version", + "test_keys", + "test_name", + "test_runtime", + "test_start_time", + "test_version", + } + assert set(m.as_dict().keys()) == expected_fields + + +def test_load_empty_file(temp_measurement_file): + """Test loading from an empty file.""" + with open(temp_measurement_file, "w") as f: + f.write("") + + measurements = load_measurements(temp_measurement_file) + assert len(measurements) == 0 + + +def test_load_file_with_blank_lines(temp_measurement_file): + """Test loading file with blank lines between measurements.""" + with open(temp_measurement_file, "w") as f: + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + f.write("\n") # blank line + f.write(json.dumps(VALID_MEASUREMENT) + "\n") + f.write("\n") # blank line + + measurements = load_measurements(temp_measurement_file) + assert len(measurements) == 2 + + +def test_file_not_found(): + """Test attempting to load from non-existent file.""" + with pytest.raises(OSError): + load_measurements("nonexistent_file.json") diff --git a/tests/aggregatetunnelmetrics/oonireport/test_main.py b/tests/aggregatetunnelmetrics/oonireport/test_main.py index f989385..cf80329 100644 --- a/tests/aggregatetunnelmetrics/oonireport/test_main.py +++ b/tests/aggregatetunnelmetrics/oonireport/test_main.py @@ -3,12 +3,10 @@ # SPDX-License-Identifier: GPL-3.0-or-later from datetime import datetime, timezone -from unittest.mock import ANY, patch - import json -import tempfile import os -import unittest +import pytest +from unittest.mock import ANY, patch from aggregatetunnelmetrics.oonireport.__main__ import main from aggregatetunnelmetrics.oonireport.model import Measurement @@ -21,249 +19,176 @@ class SimpleTestKeys: return {"simple": "test"} -class TestMain(unittest.TestCase): - def setUp(self): - self.dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) - self.valid_measurement = Measurement( - annotations={}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=self.dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=self.dt, - test_version="0.0.1", - ) - - @patch("sys.stderr") - def test_main_no_args(self, mock_stderr): - """Test main with no arguments.""" - with patch("sys.argv", ["oonireport"]): - with self.assertRaises(SystemExit): - main() - - @patch("sys.stderr") - @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") - def test_main_successful_upload(self, mock_client_class, mock_stderr): - """Test successful measurement upload.""" - # Create temporary file with measurement - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - json.dump(self.valid_measurement.as_dict(), tf) - tf.write("\n") - tf.flush() - temp_path = tf.name - - try: - # Create mock client instance - mock_client = mock_client_class.return_value - - # Mock successful API responses - mock_client.create_report_from_measurement.return_value = "test_report_id" - mock_client.update_report.return_value = "test_measurement_uid" - - # Run main with arguments - args = ["upload", "-f", temp_path] - exit_code = main(args) - - # Verify success - self.assertEqual(exit_code, 0) - - # Verify API calls - mock_client.create_report_from_measurement.assert_called_once_with( - ANY, - ) - mock_client.update_report.assert_called_once_with( - "test_report_id", - ANY, - ) - - finally: - os.unlink(temp_path) - - @patch("sys.stderr") - @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") - def test_main_failed_upload(self, mock_client_class, mock_stderr): - """Test failed measurement upload.""" - # Create temporary file with measurement - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - json.dump(self.valid_measurement.as_dict(), tf) - tf.write("\n") - tf.flush() - temp_path = tf.name - - try: - # Create mock client instance - mock_client = mock_client_class.return_value - - # Mock API failure - mock_client.create_report_from_measurement.side_effect = Exception( - "API Error" - ) - - # Run main with arguments - args = ["upload", "-f", temp_path] - exit_code = main(args) - - # Verify failure and that the method was actually called - self.assertEqual(exit_code, 1) - mock_client.create_report_from_measurement.assert_called_once_with( - ANY, - ) - - finally: - os.unlink(temp_path) - - @patch("sys.stderr") - @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") - def test_main_dump_failed(self, mock_client_class, mock_stderr): - """Test dumping failed measurements.""" - # Create temporary file with measurement - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - json.dump(self.valid_measurement.as_dict(), tf) - tf.write("\n") - tf.flush() - temp_path = tf.name - - try: - # Create mock client instance - mock_client = mock_client_class.return_value - - # Mock API failure - mock_client.create_report_from_measurement.side_effect = Exception( - "API Error" - ) - - # Run main with arguments and capture stdout - with patch("sys.stdout") as mock_stdout: - args = ["upload", "-f", temp_path, "--dump-failed"] - exit_code = main(args) - - # Verify failure - self.assertEqual(exit_code, 1) - mock_client.create_report_from_measurement.assert_called_once_with( - ANY, - ) - - # Verify failed measurement was dumped - mock_stdout.write.assert_called() - - finally: - os.unlink(temp_path) - - @patch("sys.stderr") - def test_main_invalid_file(self, mock_stderr): - """Test handling of invalid measurement file.""" - args = ["upload", "-f", "nonexistent_file.json"] +@pytest.fixture +def valid_measurement(): + """Returns a valid measurement instance for testing.""" + dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) + return Measurement( + annotations={}, + data_format_version="0.2.0", + input="https://example.com", + measurement_start_time=dt, + probe_asn="AS12345", + probe_cc="IT", + software_name="ooniprobe", + software_version="3.0.0", + test_keys=SimpleTestKeys(), + test_name="web_connectivity", + test_runtime=1.0, + test_start_time=dt, + test_version="0.0.1", + ) + + +@pytest.fixture +def temp_measurement_file(tmp_path, valid_measurement): + """Creates a temporary file with a valid measurement.""" + measurement_file = tmp_path / "measurements.json" + with open(measurement_file, "w") as f: + json.dump(valid_measurement.as_dict(), f) + f.write("\n") + return measurement_file + + +@patch("sys.stderr") +def test_main_no_args(mock_stderr): + """Test main with no arguments.""" + with pytest.raises(SystemExit): + main() + + +@patch("sys.stderr") +@patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") +def test_main_successful_upload(mock_client_class, mock_stderr, temp_measurement_file): + """Test successful measurement upload.""" + # Create mock client instance + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.submit_measurement.return_value = "test_measurement_uid" + + # Run main with arguments + args = ["upload", "-f", str(temp_measurement_file)] + exit_code = main(args) + + # Verify success + assert exit_code == 0 + + # Verify API calls + mock_client.create_report_from_measurement.assert_called_once_with(ANY) + mock_client.submit_measurement.assert_called_once_with("test_report_id", ANY) + + +@patch("sys.stderr") +@patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") +def test_main_failed_upload(mock_client_class, mock_stderr, temp_measurement_file): + """Test failed measurement upload.""" + # Create mock client instance + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.side_effect = Exception("API Error") + + # Run main with arguments + args = ["upload", "-f", str(temp_measurement_file)] + exit_code = main(args) + + # Verify failure and that the method was actually called + assert exit_code == 1 + mock_client.create_report_from_measurement.assert_called_once_with(ANY) + + +@patch("sys.stderr") +@patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") +def test_main_dump_failed(mock_client_class, mock_stderr, temp_measurement_file): + """Test dumping failed measurements.""" + # Create mock client instance + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.side_effect = Exception("API Error") + + # Run main with arguments and capture stdout + with patch("sys.stdout") as mock_stdout: + args = ["upload", "-f", str(temp_measurement_file), "--dump-failed"] + exit_code = main(args) + + # Verify failure + assert exit_code == 1 + mock_client.create_report_from_measurement.assert_called_once_with(ANY) + + # Verify failed measurement was dumped + mock_stdout.write.assert_called() + + +@patch("sys.stderr") +def test_main_invalid_file(mock_stderr): + """Test handling of invalid measurement file.""" + args = ["upload", "-f", "nonexistent_file.json"] + exit_code = main(args) + assert exit_code == 1 + + +@patch("sys.stderr") +@patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") +def test_main_delete_input(mock_client_class, mock_stderr, temp_measurement_file): + """Test deleting input file after successful upload.""" + # Create mock client instance + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.submit_measurement.return_value = "test_measurement_uid" + + # Run main with delete flag + args = ["upload", "-f", str(temp_measurement_file), "--delete-input-file"] + exit_code = main(args) + + # Verify success and file deletion + assert exit_code == 0 + assert not os.path.exists(temp_measurement_file) + + +@patch("sys.stderr") +def test_main_unknown_command(mock_stderr): + """Test main with unknown command.""" + args = ["unknown-command"] + with pytest.raises(SystemExit, match="2"): # argparse uses exit code 2 + main(args) + + +@patch("sys.stderr") +def test_main_no_measurements(mock_stderr, tmp_path): + """Test main with empty measurement file.""" + empty_file = tmp_path / "empty.json" + empty_file.touch() + + args = ["upload", "-f", str(empty_file)] + exit_code = main(args) + assert exit_code == 1 + + +@patch("sys.stderr") +@patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") +def test_main_no_measurement_uid(mock_client_class, mock_stderr, temp_measurement_file): + """Test handling of missing measurement UID.""" + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.submit_measurement.return_value = None # No measurement UID + + args = ["upload", "-f", str(temp_measurement_file)] + exit_code = main(args) + + assert exit_code == 0 + mock_client.submit_measurement.assert_called_once_with("test_report_id", ANY) + + +@patch("sys.stderr") +@patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") +def test_main_unlink_failure(mock_client_class, mock_stderr, temp_measurement_file): + """Test handling of file deletion failure.""" + mock_client = mock_client_class.return_value + mock_client.create_report_from_measurement.return_value = "test_report_id" + mock_client.submit_measurement.return_value = "test_measurement_uid" + + with patch("os.unlink") as mock_unlink: + mock_unlink.side_effect = OSError("Permission denied") + + args = ["upload", "-f", str(temp_measurement_file), "--delete-input-file"] exit_code = main(args) - self.assertEqual(exit_code, 1) - - @patch("sys.stderr") - @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") - @patch("os.unlink") - def test_main_delete_input(self, mock_unlink, mock_client_class, mock_stderr): - """Test deleting input file after successful upload.""" - # Create temporary file with measurement - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - json.dump(self.valid_measurement.as_dict(), tf) - tf.write("\n") - tf.flush() - temp_path = tf.name - - try: - # Create mock client instance - mock_client = mock_client_class.return_value - - # Mock successful API responses - mock_client.create_report_from_measurement.return_value = "test_report_id" - mock_client.update_report.return_value = "test_measurement_uid" - - # Run main with delete flag - args = ["upload", "-f", temp_path, "--delete-input-file"] - exit_code = main(args) - - # Verify success and file deletion - self.assertEqual(exit_code, 0) - mock_unlink.assert_called_once_with(temp_path) - - finally: - if os.path.exists(temp_path): - os.unlink(temp_path) - - @patch("sys.stderr") - def test_main_unknown_command(self, mock_stderr): - """Test main with unknown command.""" - args = ["unknown-command"] - with self.assertRaises(SystemExit) as cm: - main(args) - self.assertEqual(cm.exception.code, 2) # argparse uses exit code 2 - - @patch("sys.stderr") - def test_main_no_measurements(self, mock_stderr): - """Test main with empty measurement file.""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - temp_path = tf.name - - try: - args = ["upload", "-f", temp_path] - exit_code = main(args) - self.assertEqual(exit_code, 1) - finally: - os.unlink(temp_path) - - @patch("sys.stderr") - @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") - def test_main_no_measurement_uid(self, mock_client_class, mock_stderr): - """Test handling of missing measurement UID.""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - json.dump(self.valid_measurement.as_dict(), tf) - tf.write("\n") - tf.flush() - temp_path = tf.name - - try: - mock_client = mock_client_class.return_value - mock_client.create_report_from_measurement.return_value = "test_report_id" - mock_client.update_report.return_value = None # No measurement UID - - args = ["upload", "-f", temp_path] - exit_code = main(args) - - self.assertEqual(exit_code, 0) - mock_client.update_report.assert_called_once_with("test_report_id", ANY) - finally: - os.unlink(temp_path) - - @patch("sys.stderr") - @patch("aggregatetunnelmetrics.oonireport.__main__.CollectorClient") - def test_main_unlink_failure(self, mock_client_class, mock_stderr): - """Test handling of file deletion failure.""" - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tf: - json.dump(self.valid_measurement.as_dict(), tf) - tf.write("\n") - tf.flush() - temp_path = tf.name - - try: - mock_client = mock_client_class.return_value - mock_client.create_report_from_measurement.return_value = "test_report_id" - mock_client.update_report.return_value = "test_measurement_uid" - - with patch("os.unlink") as mock_unlink: - mock_unlink.side_effect = OSError("Permission denied") - - args = ["upload", "-f", temp_path, "--delete-input-file"] - exit_code = main(args) - - self.assertEqual(exit_code, 1) # Should fail if can't delete - mock_unlink.assert_called_once_with(temp_path) - finally: - os.unlink(temp_path) - - -if __name__ == "__main__": - unittest.main() + + assert exit_code == 1 # Should fail if can't delete + mock_unlink.assert_called_once() diff --git a/tests/aggregatetunnelmetrics/oonireport/test_model.py b/tests/aggregatetunnelmetrics/oonireport/test_model.py deleted file mode 100644 index 40a8ea4..0000000 --- a/tests/aggregatetunnelmetrics/oonireport/test_model.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Tests for the model module.""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -from datetime import datetime, timedelta, timezone - -import unittest - -from aggregatetunnelmetrics.oonireport import Measurement, datetime_to_ooni_format - - -class SimpleTestKeys: - """Simple TestKeys implementation for testing.""" - - def as_dict(self): - return {"simple": "test"} - - -class TestModel(unittest.TestCase): - - def test_measurement_as_dict_minimal(self): - """Test with just the mandatory fields.""" - dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) - measurement = Measurement( - annotations={"annotation_key": "value"}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=dt, - test_version="0.0.1", - ) - - data = measurement.as_dict() - - # Check mandatory fields - self.assertEqual(data["annotations"], {"annotation_key": "value"}) - self.assertEqual(data["data_format_version"], "0.2.0") - self.assertEqual(data["input"], "https://example.com") - self.assertEqual(data["measurement_start_time"], "2023-01-01 12:00:00") - self.assertEqual(data["probe_asn"], "AS12345") - self.assertEqual(data["probe_cc"], "IT") - self.assertEqual(data["software_name"], "ooniprobe") - self.assertEqual(data["software_version"], "3.0.0") - self.assertEqual(data["test_keys"], {"simple": "test"}) - self.assertEqual(data["test_name"], "web_connectivity") - self.assertEqual(data["test_runtime"], 1.0) - self.assertEqual(data["test_start_time"], "2023-01-01 12:00:00") - self.assertEqual(data["test_version"], "0.0.1") - - # Check default values - self.assertEqual(data["probe_ip"], "127.0.0.1") - self.assertEqual(data["report_id"], "") - - # Check that optional fields are not present - self.assertNotIn("options", data) - self.assertNotIn("probe_network_name", data) - self.assertNotIn("resolver_asn", data) - self.assertNotIn("resolver_ip", data) - self.assertNotIn("resolver_network_name", data) - self.assertNotIn("test_helpers", data) - - def test_measurement_as_dict_all_fields(self): - """Test with all fields, including optional ones.""" - dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) - measurement = Measurement( - # Mandatory fields - annotations={"annotation_key": "value"}, - data_format_version="0.2.0", - input="https://example.com", - measurement_start_time=dt, - probe_asn="AS12345", - probe_cc="IT", - software_name="ooniprobe", - software_version="3.0.0", - test_keys=SimpleTestKeys(), - test_name="web_connectivity", - test_runtime=1.0, - test_start_time=dt, - test_version="0.0.1", - # Fields with default values - probe_ip="93.184.216.34", - report_id="20230101_IT_test", - # Optional fields - options=["option1", "option2"], - probe_network_name="Example ISP", - resolver_asn="AS12346", - resolver_ip="8.8.8.8", - resolver_network_name="Example DNS", - test_helpers={"dns": "8.8.8.8", "web": "web-connectivity.example.org"}, - ) - - data = measurement.as_dict() - - # Check mandatory fields - self.assertEqual(data["annotations"], {"annotation_key": "value"}) - self.assertEqual(data["data_format_version"], "0.2.0") - self.assertEqual(data["input"], "https://example.com") - self.assertEqual(data["measurement_start_time"], "2023-01-01 12:00:00") - self.assertEqual(data["probe_asn"], "AS12345") - self.assertEqual(data["probe_cc"], "IT") - self.assertEqual(data["software_name"], "ooniprobe") - self.assertEqual(data["software_version"], "3.0.0") - self.assertEqual(data["test_keys"], {"simple": "test"}) - self.assertEqual(data["test_name"], "web_connectivity") - self.assertEqual(data["test_runtime"], 1.0) - self.assertEqual(data["test_start_time"], "2023-01-01 12:00:00") - self.assertEqual(data["test_version"], "0.0.1") - - # Check fields with default values - self.assertEqual(data["probe_ip"], "93.184.216.34") - self.assertEqual(data["report_id"], "20230101_IT_test") - - # Check optional fields - self.assertEqual(data["options"], ["option1", "option2"]) - self.assertEqual(data["probe_network_name"], "Example ISP") - self.assertEqual(data["resolver_asn"], "AS12346") - self.assertEqual(data["resolver_ip"], "8.8.8.8") - self.assertEqual(data["resolver_network_name"], "Example DNS") - self.assertEqual( - data["test_helpers"], - {"dns": "8.8.8.8", "web": "web-connectivity.example.org"}, - ) - - def test_datetime_to_ooni_format_utc(self): - dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone.utc) - formatted = datetime_to_ooni_format(dt) - self.assertEqual(formatted, "2023-01-01 12:00:00") - - def test_datetime_to_ooni_format_timezone(self): - # Ensure we're correctly converting when there's a timezone - dt = datetime(2023, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=2))) - formatted = datetime_to_ooni_format(dt) - self.assertEqual(formatted, "2023-01-01 10:00:00") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/aggregatetunnelmetrics/spec/test_oonicollector.py b/tests/aggregatetunnelmetrics/spec/test_oonicollector.py index 31864d4..f893520 100644 --- a/tests/aggregatetunnelmetrics/spec/test_oonicollector.py +++ b/tests/aggregatetunnelmetrics/spec/test_oonicollector.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later -from datetime import datetime, timezone +from datetime import timedelta, datetime, timezone from aggregatetunnelmetrics.spec import oonicollector @@ -151,3 +151,44 @@ def test_open_report_request_as_dict(): "test_start_time": "2024-01-01 00:00:00", "test_version": "0.1.0", } + + +def test_datetime_format_with_timezone(): + dt = datetime(2024, 1, 1, 12, 0, tzinfo=timezone(timedelta(hours=2))) + formatted = oonicollector.format_datetime(dt) + assert formatted == "2024-01-01 10:00:00" # Should convert to UTC + + +def test_measurement_as_dict_minimal_without_optionals(): + """Test that optional fields are not present in minimal measurement.""" + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + measurement = oonicollector.Measurement( + annotations={}, + data_format_version="0.2.0", + input="test", + measurement_start_time=now, + probe_asn="AS12345", + probe_cc="US", + software_name="test", + software_version="1.0", + test_keys=MockTestKeys(), + test_name="test", + test_runtime=1.0, + test_start_time=now, + test_version="1.0", + ) + + result = measurement.as_dict() + + # Verify default values + assert result["probe_ip"] == "127.0.0.1" + assert result["report_id"] == "" + + # Verify optional fields are not present + assert "options" not in result + assert "probe_network_name" not in result + assert "resolver_asn" not in result + assert "resolver_cc" not in result + assert "resolver_ip" not in result + assert "resolver_network_name" not in result + assert "test_helpers" not in result -- GitLab From 23c48f302e51935cf8f82ce63fdc52d15c89681e Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 18:26:19 +0100 Subject: [PATCH 57/75] feat(aggregators): start moving code in this package --- .../aggregators/__init__.py | 9 + aggregatetunnelmetrics/aggregators/common.py | 16 +- aggregatetunnelmetrics/aggregators/privacy.py | 160 ++++++++++++++++++ 3 files changed, 180 insertions(+), 5 deletions(-) create mode 100644 aggregatetunnelmetrics/aggregators/privacy.py diff --git a/aggregatetunnelmetrics/aggregators/__init__.py b/aggregatetunnelmetrics/aggregators/__init__.py index e69de29..42d52c7 100644 --- a/aggregatetunnelmetrics/aggregators/__init__.py +++ b/aggregatetunnelmetrics/aggregators/__init__.py @@ -0,0 +1,9 @@ +""" +Aggregation Code +================ + +This package contains code to aggregate tunnel metrics +using specific aggregation policies. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py index 5becd17..ca260d8 100644 --- a/aggregatetunnelmetrics/aggregators/common.py +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -2,7 +2,8 @@ Common Aggregation Code ======================= -TODO... +This module contains the base classes and functions +necessary for aggregating tunnel metrics. """ # SPDX-License-Identifier: GPL-3.0-or-later @@ -10,11 +11,11 @@ TODO... from dataclasses import dataclass, field from statistics import quantiles -from typing import Iterator, Protocol, runtime_checkable -from . import fieldtesting -from . import metrics -from . import oonicollector +from ..spec import ( + fieldtesting, + metrics, +) def make_distribution(values: list[float]) -> metrics.Distribution | None: @@ -107,7 +108,10 @@ class PingMetricsPerTarget: metrics.TunnelPingStatement( target_address=self.target_address, sample_size=self.num_samples, + latency_min=make_distribution(self.min), latency_avg=make_distribution(self.avg), + latency_max=make_distribution(self.max), + loss=make_distribution(self.loss), ) ] @@ -195,6 +199,7 @@ class NDTMetricsPerTarget: sample_size=self.num_samples, latency=make_distribution(self.download_latency), speed=make_distribution(self.download_throughput), + rexmit=make_distribution(self.download_rexmit), ), metrics.TunnelNDTStatement( direction="upload", @@ -204,6 +209,7 @@ class NDTMetricsPerTarget: sample_size=self.num_samples, latency=make_distribution(self.upload_latency), speed=make_distribution(self.upload_throughput), + rexmit=make_distribution(self.upload_rexmit), ), ] diff --git a/aggregatetunnelmetrics/aggregators/privacy.py b/aggregatetunnelmetrics/aggregators/privacy.py new file mode 100644 index 0000000..66cf716 --- /dev/null +++ b/aggregatetunnelmetrics/aggregators/privacy.py @@ -0,0 +1,160 @@ +""" +Privacy and Quantization Filters +================================ + +This module contains filters that ensure privacy and +reasonable quantization of metrics by: + +1. Filtering out small sample sizes from the output +2. Rounding sample sizes to reduce precision +3. Applying consistent rules across all metrics + +The filters are stateless and work by transforming input +structures into filtered versions. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass, replace + +from ..spec import metrics + + +@dataclass(frozen=True) +class Config: + """ + Configuration for privacy and quantization filters. + + Fields: + min_sample_size: minimum samples (default: 1000) + round_to: round sample sizes to this value (default: 100) + """ + + min_sample_size: int = 1000 + round_to: int = 100 + + +def filter_sample_size(sample_size: int, config: Config) -> int | None: + """ + Apply privacy rules to sample sizes. + + Args: + sample_size: The raw sample size + config: The filtering configuration + + Returns: + The filtered sample size or None if below minimum + """ + return ( + (sample_size // config.round_to) * config.round_to + if sample_size >= config.min_sample_size + else None + ) + + +def filter_network_error( + stmt: metrics.NetworkErrorStatement, + config: Config, +) -> metrics.NetworkErrorStatement: + """ + Filter network error statements. + + Args: + stmt: The statement to filter + config: The filtering configuration + + Returns: + Filtered statement with sample size removed if below threshold + """ + if stmt.sample_size is None: + return stmt + + # Round failure ratio to 2 decimal places + rounded_ratio = round(stmt.failure_ratio, 2) + + # Sample size may be None if below threshold + filtered_size = filter_sample_size(stmt.sample_size, config) + + return replace( + stmt, + sample_size=filtered_size, + failure_ratio=rounded_ratio, + ) + + +def filter_tunnel_ping( + stmt: metrics.TunnelPingStatement, config: Config +) -> metrics.TunnelPingStatement: + """ + Filter tunnel ping statements. + + Args: + stmt: The statement to filter + config: The filtering configuration + + Returns: + Filtered statement with sample size removed if below threshold + """ + # Sample size may be None if below threshold + filtered_size = filter_sample_size(stmt.sample_size, config) + + return replace(stmt, sample_size=filtered_size) + + +def filter_tunnel_ndt( + stmt: metrics.TunnelNDTStatement, config: Config +) -> metrics.TunnelNDTStatement: + """ + Filter tunnel NDT statements. + + Args: + stmt: The statement to filter + config: The filtering configuration + + Returns: + Filtered statement with sample size removed if below threshold + """ + if stmt.sample_size is None: + return stmt + + # Sample size may be None if below threshold + filtered_size = filter_sample_size(stmt.sample_size, config) + + return replace(stmt, sample_size=filtered_size) + + +def filter_statement(stmt: metrics.Statement, config: Config) -> metrics.Statement: + """ + Filter any statement type. + + Args: + stmt: The statement to filter + config: The filtering configuration + + Returns: + Filtered statement with sample size removed if below threshold + """ + if isinstance(stmt, metrics.NetworkErrorStatement): + return filter_network_error(stmt, config) + elif isinstance(stmt, metrics.TunnelPingStatement): + return filter_tunnel_ping(stmt, config) + elif isinstance(stmt, metrics.TunnelNDTStatement): + return filter_tunnel_ndt(stmt, config) + return stmt + + +def filter_test_keys( + tk: metrics.MeasurementTestKeys, config: Config +) -> metrics.MeasurementTestKeys: + """ + Filter measurement test keys. + + Args: + tk: The test keys to filter + config: The filtering configuration + + Returns: + New test keys with filtered statements + """ + filtered_bodies = [filter_statement(stmt, config) for stmt in tk.bodies] + return replace(tk, bodies=filtered_bodies) -- GitLab From 65e8a36c3cc1161160c068ab325de95b9a564d6e Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 18:27:51 +0100 Subject: [PATCH 58/75] feat: add endpoint_pool_name to CSV spec --- aggregatetunnelmetrics/spec/fieldtesting.py | 1 + tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py | 1 + 2 files changed, 2 insertions(+) diff --git a/aggregatetunnelmetrics/spec/fieldtesting.py b/aggregatetunnelmetrics/spec/fieldtesting.py index 08dfa91..3a84345 100644 --- a/aggregatetunnelmetrics/spec/fieldtesting.py +++ b/aggregatetunnelmetrics/spec/fieldtesting.py @@ -55,6 +55,7 @@ class Entry: ndt_target_hostname: str | None ndt_target_address: str | None ndt_target_port: int | None + endpoint_pool_name: str | None def is_tunnel_measurement(self) -> bool: """ diff --git a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py index 7e8732d..a38c531 100644 --- a/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py +++ b/tests/aggregatetunnelmetrics/fieldtestingcsv/test_parser.py @@ -230,6 +230,7 @@ def make_entry(): ndt_target_hostname=None, ndt_target_address=None, ndt_target_port=None, + endpoint_pool_name=None, ) return _make_entry -- GitLab From f7dae8464dbe07bb9ce372791a9b79dd1d9b8802 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 20:26:32 +0100 Subject: [PATCH 59/75] feat: finish writing the aggregators package --- aggregatetunnelmetrics/aggregators/common.py | 42 ++-- .../aggregators/endpointpool.py | 171 ++++++++++++++++ aggregatetunnelmetrics/aggregators/privacy.py | 6 +- .../fieldtestingcsv/parser.py | 1 + .../aggregators/__init__.py | 3 + .../aggregators/conftest.py | 43 ++++ .../aggregators/test_common.py | 168 ++++++++++++++++ .../aggregators/test_endpointpool.py | 58 ++++++ .../aggregators/test_privacy.py | 183 ++++++++++++++++++ .../spec/test_fieldtesting.py | 1 + 10 files changed, 662 insertions(+), 14 deletions(-) create mode 100644 aggregatetunnelmetrics/aggregators/endpointpool.py create mode 100644 tests/aggregatetunnelmetrics/aggregators/__init__.py create mode 100644 tests/aggregatetunnelmetrics/aggregators/conftest.py create mode 100644 tests/aggregatetunnelmetrics/aggregators/test_common.py create mode 100644 tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py create mode 100644 tests/aggregatetunnelmetrics/aggregators/test_privacy.py diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py index ca260d8..2bb0b80 100644 --- a/aggregatetunnelmetrics/aggregators/common.py +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -10,7 +10,7 @@ necessary for aggregating tunnel metrics. from dataclasses import dataclass, field -from statistics import quantiles +from statistics import StatisticsError, quantiles from ..spec import ( fieldtesting, @@ -20,10 +20,10 @@ from ..spec import ( def make_distribution(values: list[float]) -> metrics.Distribution | None: """Generates an empirical distribution from a list of values.""" - # TODO(bassosimone): this code is most likely wrong - if not values: + try: + q = quantiles(values, n=100, method="exclusive") + except StatisticsError: return None - q = quantiles(values, n=100, method="exclusive") return metrics.Distribution( p25=q[24], p50=q[49], @@ -249,11 +249,9 @@ class NDTMetricsOverall: @dataclass -class AggregationUnitMetrics: +class ProtocolMetrics: """ - Allows tracking an aggregation unit metrics. - - The aggregation unit depends on the aggregation policy. + Tracks metrics for a specific VPN/bridge protocol. Fields: creation: The creation metrics. @@ -261,9 +259,9 @@ class AggregationUnitMetrics: tunnel_ndt: The tunnel NDT metrics. """ - creation: CreationMetrics = CreationMetrics() - tunnel_ping: PingMetricsOverall = PingMetricsOverall() - tunnel_ndt: NDTMetricsOverall = NDTMetricsOverall() + creation: CreationMetrics = field(default_factory=CreationMetrics) + tunnel_ping: PingMetricsOverall = field(default_factory=PingMetricsOverall) + tunnel_ndt: NDTMetricsOverall = field(default_factory=NDTMetricsOverall) def update(self, entry: fieldtesting.Entry) -> None: """Updates the metrics with a new entry.""" @@ -278,3 +276,25 @@ class AggregationUnitMetrics: result.extend(self.tunnel_ping.statements()) result.extend(self.tunnel_ndt.statements()) return result + + +@dataclass(frozen=True) +class UpstreamCollector: + """ + Describes the collector that collected the field + testing data, which is "upstream" when observed from + the point of view of the OONI collector. + + Fields: + asn: The Autonomous System Number (ASN) of the upstream collector. + cc: The country code of the upstream collector. + name: The name of the upstream collector. + software_name: The name of the software used by the upstream collector. + software_version: The version of the software used by the upstream collector + """ + + asn: str + cc: str + name: str + software_name: str + software_version: str diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py new file mode 100644 index 0000000..b71e22e --- /dev/null +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -0,0 +1,171 @@ +""" +Endpoint Pool Scope Aggregation +=============================== + +This module implements aggregation using endpoint pool scope. +""" + +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Iterator +from urllib.parse import urlencode, urlunparse + +from ..spec import ( + fieldtesting, + metrics, + oonicollector, +) + +from .common import ProtocolMetrics, UpstreamCollector +from . import privacy + + +@dataclass +class ProtocolAggregator: + """ + State for aggregating by protocol. + + Fields: + protocols: Maps protocol name to ProtocolSpecificMetrics. + """ + + protocols: dict[str, ProtocolMetrics] = field(default_factory=dict) + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + key = f"{entry.protocol}" + if not key in self.protocols: + self.protocols[key] = ProtocolMetrics() + self.protocols[key].update(entry) + + +@dataclass +class PoolAggregator: + """ + State for the endpoint_pool-scope aggregation. + + Fields: + pools: Maps pool name to ProtocolAggregator. + """ + + pools: dict[str, ProtocolAggregator] = field(default_factory=dict) + + def update(self, entry: fieldtesting.Entry) -> None: + """Updates the metrics with a new entry.""" + key = f"{entry.endpoint_pool_name}" + if not key in self.pools: + self.pools[key] = ProtocolAggregator() + self.pools[key].update(entry) + + +@dataclass(frozen=True) +class Aggregator: + """ + Implements aggregator.Logic for endpoint pool scope. + + Fields: + provider: Name of the VPN provider. + privacy_config: Configuration for privacy filters. + state: Mutable aggregation state. + """ + + provider: str + pool_country: str + time_window: metrics.TimeWindow + upstream_collector: UpstreamCollector + privacy_config: privacy.Config = privacy.Config() + state: PoolAggregator = field(default_factory=PoolAggregator) + + def aggregate( + self, + entries: Iterator[fieldtesting.Entry], + ) -> Iterator[oonicollector.Measurement]: + # Walk through entries updating the mutable state + for entry in entries: + self.state.update(entry) + + # Serialize and yield each measurement + for pool_name, pool_values in self.state.pools.items(): + for proto_name, proto_metrics in pool_values.protocols.items(): + yield self._create_measurement(pool_name, proto_name, proto_metrics) + + def _create_measurement( + self, + pool_name: str, + proto_name: str, + proto_metrics: ProtocolMetrics, + ) -> oonicollector.Measurement: + """ + Creates a new OONI Measurement from the given metrics. + + Args: + aggr: The metrics to convert. + + Returns: + The OONI Measurement. + """ + + # Serialize the bodies to a list of statements + bodies = proto_metrics.statements() + + # Apply privacy filters to the bodies + bodies = [ + privacy.filter_statement(stmt, self.privacy_config) for stmt in bodies + ] + + # Get the current time for the measurement timestamps + measurement_time = datetime.now(timezone.utc) + + # Create the test keys with the filtered statements + test_keys = metrics.MeasurementTestKeys( + provider=self.provider, + scope=metrics.EndpointPoolScope( + protocol=proto_name, + cc=self.pool_country, + ), + bodies=bodies, + time_window=self.time_window, + ) + + return oonicollector.Measurement( + annotations={ + "upstream_collector": self.upstream_collector.name, + }, + data_format_version="0.2.0", + input=self._create_input_url(pool_name, proto_name), + measurement_start_time=measurement_time, + probe_asn=self.upstream_collector.asn, + probe_cc=self.upstream_collector.cc, + software_name=self.upstream_collector.software_name, + software_version=self.upstream_collector.software_version, + test_keys=test_keys, + test_name="aggregate_tunnel_metrics", + test_runtime=0.0, # Not relevant for aggregate data + test_start_time=measurement_time, + test_version="0.1.0", + ) + + def _create_input_url(self, pool_name: str, proto_name: str) -> str: + """Create the measurement input URL""" + # Only include the pool name if it is not None, which is + # serialized to the "None" string by the code above. + # + # This is slightly sketchy but overall not a disaster + # because this code belongs to the same file. + # + # TODO(bassosimone): consider using a cleaner approach here. + query = {} + if pool_name and pool_name != "None": + query["endpoint_pool_name"] = pool_name + + # Build URL using urlunparse for safety + return urlunparse( + ( + proto_name, # scheme (e.g., "openvpn+obfs4") + self.provider, # netloc (e.g., "riseup.net") + "/", # path + "", # params + urlencode(query), # query (e.g., "address=1.2.3.4&...") + "", # fragment + ) + ) diff --git a/aggregatetunnelmetrics/aggregators/privacy.py b/aggregatetunnelmetrics/aggregators/privacy.py index 66cf716..7e349c2 100644 --- a/aggregatetunnelmetrics/aggregators/privacy.py +++ b/aggregatetunnelmetrics/aggregators/privacy.py @@ -136,11 +136,11 @@ def filter_statement(stmt: metrics.Statement, config: Config) -> metrics.Stateme """ if isinstance(stmt, metrics.NetworkErrorStatement): return filter_network_error(stmt, config) - elif isinstance(stmt, metrics.TunnelPingStatement): + if isinstance(stmt, metrics.TunnelPingStatement): return filter_tunnel_ping(stmt, config) - elif isinstance(stmt, metrics.TunnelNDTStatement): + if isinstance(stmt, metrics.TunnelNDTStatement): return filter_tunnel_ndt(stmt, config) - return stmt + raise TypeError(f"unsupported statement type {type(stmt)}") def filter_test_keys( diff --git a/aggregatetunnelmetrics/fieldtestingcsv/parser.py b/aggregatetunnelmetrics/fieldtestingcsv/parser.py index e2eec85..fd037d5 100644 --- a/aggregatetunnelmetrics/fieldtestingcsv/parser.py +++ b/aggregatetunnelmetrics/fieldtestingcsv/parser.py @@ -93,6 +93,7 @@ def parse_single_row(row: Dict[str, str]) -> Entry: ndt_target_hostname=None, ndt_target_address=None, ndt_target_port=None, + endpoint_pool_name=None, ) diff --git a/tests/aggregatetunnelmetrics/aggregators/__init__.py b/tests/aggregatetunnelmetrics/aggregators/__init__.py new file mode 100644 index 0000000..2a6977d --- /dev/null +++ b/tests/aggregatetunnelmetrics/aggregators/__init__.py @@ -0,0 +1,3 @@ +"""Tests for the aggregators package.""" + +# SPDX-License-Identifier: GPL-3.0-or-later diff --git a/tests/aggregatetunnelmetrics/aggregators/conftest.py b/tests/aggregatetunnelmetrics/aggregators/conftest.py new file mode 100644 index 0000000..a94e140 --- /dev/null +++ b/tests/aggregatetunnelmetrics/aggregators/conftest.py @@ -0,0 +1,43 @@ +"""Common test fixtures.""" + +# SPDX-License-Identifier: GPL-3.0-or-later +# +from datetime import datetime, timezone + +import pytest + +from aggregatetunnelmetrics.spec import fieldtesting + + +@pytest.fixture +def sample_entry(): + return fieldtesting.Entry( + filename="test.csv", + date=datetime(2024, 1, 1, 0, 0, 0, 0, tzinfo=timezone.utc), + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="test_user", + region="test_region", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel", + throughput_download=10.0, + throughput_upload=5.0, + latency_download=100.0, + latency_upload=120.0, + retransmission_download=0.01, + retransmission_upload=0.02, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + protocol="openvpn", + ping_target_address="8.8.8.8", + ndt_target_hostname="ndt.server.com", + ndt_target_address="2.2.2.2", + ndt_target_port=443, + endpoint_pool_name="test_pool", + ) diff --git a/tests/aggregatetunnelmetrics/aggregators/test_common.py b/tests/aggregatetunnelmetrics/aggregators/test_common.py new file mode 100644 index 0000000..3861219 --- /dev/null +++ b/tests/aggregatetunnelmetrics/aggregators/test_common.py @@ -0,0 +1,168 @@ +"""Tests for the aggregators.common module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import replace + +import pytest +import statistics + +from aggregatetunnelmetrics.aggregators import common +from aggregatetunnelmetrics.spec import fieldtesting, metrics + + +def test_creation_metrics(sample_entry): + mx = common.CreationMetrics() + mx.update(sample_entry) + + assert mx.num_samples == 1 + assert len(mx.errors) == 0 + + # Test error case + error_entry = replace(sample_entry, tunnel="ERROR/tunnel") + mx.update(error_entry) + + assert mx.num_samples == 2 + assert len(mx.errors) == 1 + assert mx.errors["bootstrap.generic_error"] == 1 + + +def test_creation_metrics_statements(sample_entry): + """Test that CreationMetrics.statements() properly generates statements""" + mx = common.CreationMetrics() + + # Add an error entry to ensure we get statements + error_entry = fieldtesting.Entry( + **{**sample_entry.__dict__, "tunnel": "ERROR/tunnel"} + ) + mx.update(error_entry) + + statements = mx.statements() + assert len(statements) == 1 + assert isinstance(statements[0], metrics.NetworkErrorStatement) + assert statements[0].error == "bootstrap.generic_error" + assert statements[0].sample_size == 1 + + +def test_ping_metrics_per_target(sample_entry): + metrics = common.PingMetricsPerTarget(target_address="8.8.8.8") + metrics.update(sample_entry) + + assert metrics.num_samples == 1 + assert len(metrics.min) == 1 + assert metrics.min[0] == 10.0 + assert metrics.avg[0] == 15.0 + assert metrics.max[0] == 20.0 + assert metrics.loss[0] == 0.0 + + +def test_ndt_metrics_per_target(sample_entry): + metrics = common.NDTMetricsPerTarget( + target_hostname="ndt.server.com", target_address="2.2.2.2", target_port=443 + ) + metrics.update(sample_entry) + + assert metrics.num_samples == 1 + assert len(metrics.download_throughput) == 1 + assert metrics.download_throughput[0] == 10.0 + assert metrics.upload_throughput[0] == 5.0 + + +def test_make_distribution_empty(): + """Test distribution with no values.""" + assert common.make_distribution([]) is None + + +def test_make_distribution_single_value(): + """Test distribution with a single value""" + dist = common.make_distribution([1.0]) + assert dist is not None + assert dist.p25 == pytest.approx(1.0) + assert dist.p50 == pytest.approx(1.0) + assert dist.p75 == pytest.approx(1.0) + assert dist.p99 == pytest.approx(1.0) + + +def test_make_distribution_two_values(): + """Test distribution with two values""" + dist = common.make_distribution([1.0, 2.0]) + assert dist is not None + assert dist.p25 == pytest.approx(0.75) + assert dist.p50 == pytest.approx(1.5) + assert dist.p75 == pytest.approx(2.25) + assert dist.p99 == pytest.approx(2.97) + +def test_make_distribution_three_values(): + """Test distribution with three values""" + # With three values, we should get valid quantiles + dist = common.make_distribution([1.0, 2.0, 3.0]) + assert dist is not None + assert dist.p25 == pytest.approx(1.0) + assert dist.p50 == pytest.approx(2.0) + assert dist.p75 == pytest.approx(3.0) + assert dist.p99 == pytest.approx(3.96) + + +def test_make_distribution_four_values(): + """Test distribution with four values""" + dist = common.make_distribution([1.0, 2.0, 3.0, 4.0]) + assert dist is not None + assert dist.p25 == pytest.approx(1.25) + assert dist.p50 == pytest.approx(2.5) + assert dist.p75 == pytest.approx(3.75) + assert dist.p99 == pytest.approx(4.95) + + +def test_make_distribution_five_values(): + """Test distribution with five values""" + dist = common.make_distribution([1.0, 2.0, 3.0, 4.0, 5.0]) + assert dist is not None + assert dist.p25 == pytest.approx(1.5) + assert dist.p50 == pytest.approx(3.0) + assert dist.p75 == pytest.approx(4.5) + assert dist.p99 == pytest.approx(5.94) + + +def test_make_distribution_large_dataset(): + """Test distribution with a larger dataset""" + # Create array with values from 1 to 100 + values = [float(x) for x in range(1, 101)] + dist = common.make_distribution(values) + assert dist is not None + assert dist.p25 == pytest.approx(25.25) + assert dist.p50 == pytest.approx(50.5) + assert dist.p75 == pytest.approx(75.75) + assert dist.p99 == pytest.approx(99.99) + + +def test_make_distribution_with_duplicates(): + """Test distribution with duplicate values""" + values = [1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0] + dist = common.make_distribution(values) + assert dist is not None + assert dist.p25 == pytest.approx(1.25) + assert dist.p50 == pytest.approx(2.5) + assert dist.p75 == pytest.approx(3.75) + assert dist.p99 == pytest.approx(4.0) + + +def test_make_distribution_with_negative_values(): + """Test distribution with negative values""" + values = [-4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0] + dist = common.make_distribution(values) + assert dist is not None + assert dist.p25 == pytest.approx(-2.5) + assert dist.p50 == pytest.approx(0.0) + assert dist.p75 == pytest.approx(2.5) + assert dist.p99 == pytest.approx(4.9) + + +def test_make_distribution_all_same_value(): + """Test distribution where all values are the same""" + values = [1.0] * 10 + dist = common.make_distribution(values) + assert dist is not None + assert dist.p25 == pytest.approx(1.0) + assert dist.p50 == pytest.approx(1.0) + assert dist.p75 == pytest.approx(1.0) + assert dist.p99 == pytest.approx(1.0) diff --git a/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py b/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py new file mode 100644 index 0000000..31461f5 --- /dev/null +++ b/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py @@ -0,0 +1,58 @@ +"""Tests for the aggregators.endpointpool module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone + +import pytest + +from aggregatetunnelmetrics.aggregators import endpointpool +from aggregatetunnelmetrics.spec import metrics + + +@pytest.fixture +def upstream_collector(): + return endpointpool.UpstreamCollector( + asn="AS12345", + cc="XX", + name="test_collector", + software_name="test_software", + software_version="1.0.0", + ) + + +def test_protocol_aggregator(sample_entry): + aggregator = endpointpool.ProtocolAggregator() + aggregator.update(sample_entry) + + assert len(aggregator.protocols) == 1 + assert "openvpn" in aggregator.protocols + + +def test_pool_aggregator(sample_entry): + aggregator = endpointpool.PoolAggregator() + aggregator.update(sample_entry) + + assert len(aggregator.pools) == 1 + assert "test_pool" in aggregator.pools + + +def test_aggregator(sample_entry, upstream_collector): + time_window = metrics.TimeWindow( + start=datetime.now(timezone.utc), end=datetime.now(timezone.utc) + ) + + aggregator = endpointpool.Aggregator( + provider="test_provider", + pool_country="XX", + time_window=time_window, + upstream_collector=upstream_collector, + ) + + measurements = list(aggregator.aggregate(iter([sample_entry]))) + assert len(measurements) == 1 + + measurement = measurements[0] + assert measurement.test_name == "aggregate_tunnel_metrics" + assert measurement.probe_asn == upstream_collector.asn + assert measurement.probe_cc == upstream_collector.cc diff --git a/tests/aggregatetunnelmetrics/aggregators/test_privacy.py b/tests/aggregatetunnelmetrics/aggregators/test_privacy.py new file mode 100644 index 0000000..37a2480 --- /dev/null +++ b/tests/aggregatetunnelmetrics/aggregators/test_privacy.py @@ -0,0 +1,183 @@ +"""Tests for the aggregators.privacy module.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timezone + +import pytest + +from aggregatetunnelmetrics.spec import metrics +from aggregatetunnelmetrics.aggregators import privacy + + +def test_filter_sample_size(): + config = privacy.Config(min_sample_size=1000, round_to=100) + + # Test below minimum + assert privacy.filter_sample_size(999, config) is None + + # Test rounding + assert privacy.filter_sample_size(1234, config) == 1200 + assert privacy.filter_sample_size(1050, config) == 1000 + + +def test_filter_network_error(): + config = privacy.Config(min_sample_size=1000, round_to=100) + + stmt = metrics.NetworkErrorStatement( + sample_size=1234, failure_ratio=0.1234, error="test_error" + ) + + filtered = privacy.filter_network_error(stmt, config) + assert filtered.sample_size == 1200 + assert filtered.failure_ratio == 0.12 # Rounded to 2 decimal places + assert filtered.error == "test_error" + + # Test below minimum sample size + stmt = metrics.NetworkErrorStatement( + sample_size=500, failure_ratio=0.1, error="test_error" + ) + filtered = privacy.filter_network_error(stmt, config) + assert filtered.sample_size is None + + +def test_filter_tunnel_ping(): + config = privacy.Config(min_sample_size=1000, round_to=100) + + distribution = metrics.Distribution(p25=1.0, p50=2.0, p75=3.0, p99=4.0) + stmt = metrics.TunnelPingStatement( + target_address="1.1.1.1", + sample_size=1234, + latency_min=distribution, + latency_avg=distribution, + latency_max=distribution, + loss=distribution, + ) + + filtered = privacy.filter_tunnel_ping(stmt, config) + assert filtered.sample_size == 1200 + assert filtered.target_address == "1.1.1.1" + assert filtered.latency_min == distribution + + +def test_filter_tunnel_ndt(): + config = privacy.Config(min_sample_size=1000, round_to=100) + + distribution = metrics.Distribution(p25=1.0, p50=2.0, p75=3.0, p99=4.0) + stmt = metrics.TunnelNDTStatement( + direction="download", + target_hostname="example.com", + target_address="1.1.1.1", + target_port=443, + sample_size=1234, + latency=distribution, + speed=distribution, + rexmit=distribution, + ) + + filtered = privacy.filter_tunnel_ndt(stmt, config) + assert filtered.sample_size == 1200 + assert filtered.direction == "download" + assert filtered.target_hostname == "example.com" + + +def test_filter_network_error_none_sample(): + """Test filtering network error statement with None sample size""" + stmt = metrics.NetworkErrorStatement( + sample_size=None, failure_ratio=0.5, error="test_error" + ) + config = privacy.Config() + + filtered = privacy.filter_network_error(stmt, config) + assert filtered.sample_size is None + assert filtered.failure_ratio == 0.5 + assert filtered.error == "test_error" + + +def test_filter_tunnel_ndt_none_sample(): + """Test filtering NDT statement with None sample size""" + distribution = metrics.Distribution(p25=1.0, p50=2.0, p75=3.0, p99=4.0) + stmt = metrics.TunnelNDTStatement( + direction="download", + target_hostname="example.com", + target_address="1.1.1.1", + target_port=443, + sample_size=None, + latency=distribution, + speed=distribution, + rexmit=distribution, + ) + config = privacy.Config() + + filtered = privacy.filter_tunnel_ndt(stmt, config) + assert filtered.sample_size is None + + +def test_filter_statement_network_error(): + """Test filtering network error through generic filter""" + stmt = metrics.NetworkErrorStatement( + sample_size=1234, failure_ratio=0.5, error="test_error" + ) + config = privacy.Config() + + filtered = privacy.filter_statement(stmt, config) + assert isinstance(filtered, metrics.NetworkErrorStatement) + assert filtered.sample_size == 1200 # rounded as per config + + +class UnsupportedStatement(metrics.Statement): + """Mock unsupported statement type""" + + def as_dict(self) -> dict: + return {} + + +def test_filter_statement_unsupported(): + """Test that filtering unsupported statement type raises TypeError""" + stmt = UnsupportedStatement() + config = privacy.Config() + + with pytest.raises(TypeError, match="unsupported statement type"): + privacy.filter_statement(stmt, config) + + +def test_filter_test_keys(): + """Test filtering complete test keys""" + # Create a mix of different statement types + statements = [ + metrics.NetworkErrorStatement( + sample_size=1234, failure_ratio=0.5, error="test_error" + ), + metrics.TunnelPingStatement( + target_address="1.1.1.1", + sample_size=2345, + latency_min=metrics.Distribution(p25=1.0, p50=2.0, p75=3.0, p99=4.0), + latency_avg=metrics.Distribution(p25=1.0, p50=2.0, p75=3.0, p99=4.0), + latency_max=metrics.Distribution(p25=1.0, p50=2.0, p75=3.0, p99=4.0), + loss=metrics.Distribution(p25=0.0, p50=0.0, p75=0.0, p99=0.1), + ), + ] + + test_keys = metrics.MeasurementTestKeys( + provider="test_provider", + scope=metrics.GlobalScope(protocol="test_protocol"), + time_window=metrics.TimeWindow( + start=datetime.now(timezone.utc), end=datetime.now(timezone.utc) + ), + bodies=statements, + ) + + config = privacy.Config() + filtered = privacy.filter_test_keys(test_keys, config) + + assert len(filtered.bodies) == 2 + + # Check each statement's rounded sample size individually + assert isinstance(filtered.bodies[0], metrics.NetworkErrorStatement) + assert filtered.bodies[0].sample_size == 1200 # 1234 rounded down + assert isinstance(filtered.bodies[1], metrics.TunnelPingStatement) + assert filtered.bodies[1].sample_size == 2300 # 2345 rounded down + + # Verify the statements maintain their original types + assert isinstance(filtered.bodies[0], metrics.NetworkErrorStatement) + assert isinstance(filtered.bodies[1], metrics.TunnelPingStatement) diff --git a/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py b/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py index b426ffb..4d6204b 100644 --- a/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py +++ b/tests/aggregatetunnelmetrics/spec/test_fieldtesting.py @@ -36,6 +36,7 @@ def test_entry_tunnel_measurement(): ndt_target_hostname="ndt.server.com", ndt_target_address="2.2.2.2", ndt_target_port=3001, + endpoint_pool_name="default", ) assert entry.is_tunnel_measurement() is True -- GitLab From 5389e5dd0c455afd4b95c017f5202da3b220e289 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 23 Feb 2025 23:57:25 +0100 Subject: [PATCH 60/75] feat: restructure the pipeline to use new code --- .../aggregators/endpointpool.py | 6 +- aggregatetunnelmetrics/pipeline/__init__.py | 11 +- aggregatetunnelmetrics/pipeline/config.py | 69 ------ aggregatetunnelmetrics/pipeline/pipeline.py | 216 ++++++++++++++++++ aggregatetunnelmetrics/pipeline/processor.py | 151 ------------ aggregatetunnelmetrics/pipeline/state.py | 26 ++- .../pipeline/windowpolicy.py | 12 +- aggregatetunnelmetrics/spec/aggregator.py | 6 +- .../aggregators/test_common.py | 1 + .../pipeline/test_state.py | 24 +- 10 files changed, 259 insertions(+), 263 deletions(-) delete mode 100644 aggregatetunnelmetrics/pipeline/config.py create mode 100644 aggregatetunnelmetrics/pipeline/pipeline.py delete mode 100644 aggregatetunnelmetrics/pipeline/processor.py diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py index b71e22e..97d9c59 100644 --- a/aggregatetunnelmetrics/aggregators/endpointpool.py +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -7,7 +7,7 @@ This module implements aggregation using endpoint pool scope. from dataclasses import dataclass, field from datetime import datetime, timezone -from typing import Iterator +from typing import Generator from urllib.parse import urlencode, urlunparse from ..spec import ( @@ -78,8 +78,8 @@ class Aggregator: def aggregate( self, - entries: Iterator[fieldtesting.Entry], - ) -> Iterator[oonicollector.Measurement]: + entries: Generator[fieldtesting.Entry], + ) -> Generator[oonicollector.Measurement]: # Walk through entries updating the mutable state for entry in entries: self.state.update(entry) diff --git a/aggregatetunnelmetrics/pipeline/__init__.py b/aggregatetunnelmetrics/pipeline/__init__.py index db9fad4..b567d27 100644 --- a/aggregatetunnelmetrics/pipeline/__init__.py +++ b/aggregatetunnelmetrics/pipeline/__init__.py @@ -7,18 +7,15 @@ and submitting the resulting metrics to OONI. # SPDX-License-Identifier: GPL-3.0-or-later -from .config import ProcessConfig, FileIOConfig from .errors import PipelineError, StateError -from .processor import MetricsProcessor -from .state import ProcessorState +from .pipeline import Pipeline +from .state import PipelineState from .windowpolicy import Policy __all__ = [ - "FileIOConfig", - "MetricsProcessor", + "Pipeline", "PipelineError", "Policy", - "ProcessorState", - "ProcessConfig", + "PipelineState", "StateError", ] diff --git a/aggregatetunnelmetrics/pipeline/config.py b/aggregatetunnelmetrics/pipeline/config.py deleted file mode 100644 index 6f3c488..0000000 --- a/aggregatetunnelmetrics/pipeline/config.py +++ /dev/null @@ -1,69 +0,0 @@ -""" -Configuration classes for the metrics processing pipeline. - -This module is internal. Please, import `pipeline` directly instead. -""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -from dataclasses import dataclass - -from .. import lockedfile - - -@dataclass(frozen=True) -class ProcessConfig: - """ - Configuration for the full metrics processing pipeline. - - For safety, you must explicitly provide a collector_base_url. You should configure - it to point either to a testing server or to `https://api.ooni.io/`. - - Fields: - provider: Name of the metrics provider. - upstream_collector: Name of the collector used to collect the CSV files. - probe_asn: ASN of the collector (becomes probe_asn in the OONI measurement). - probe_cc: Country code of the collector (becomes probe_cc in the OONI measurement). - min_sample_size: Minimum number of samples to include statistical information. - collector_base_url: Base URL of the OONI collector to use (mandatory). - timeout: Timeout for HTTP requests. - """ - - # Core identification - provider: str - - # Configuration for filling the measurement - upstream_collector: str - probe_asn: str - probe_cc: str - - # Mandatory collector configuration - collector_base_url: str - - # Optional measurement-filling configuration - min_sample_size: int = 1000 - - # Optional collector configuration - timeout: float = 30.0 - - -@dataclass(frozen=True) -class FileIOConfig: - """ - Configuration for file I/O operations. - - Fields: - state_file: Path to the file where to store state information. - num_retries: Number of retries to perform when acquiring the file lock. - sleep_interval: Time to wait between retries when acquiring the lock. - """ - - state_file: str - num_retries: int = 10 - sleep_interval: float = 0.1 - - def as_lockedfile_fileio_config(self) -> lockedfile.ReadWriteConfig: - """Convert to a lockedfile.FileIOConfig.""" - return lockedfile.ReadWriteConfig( - num_retries=self.num_retries, sleep_interval=self.sleep_interval - ) diff --git a/aggregatetunnelmetrics/pipeline/pipeline.py b/aggregatetunnelmetrics/pipeline/pipeline.py new file mode 100644 index 0000000..00d3bd3 --- /dev/null +++ b/aggregatetunnelmetrics/pipeline/pipeline.py @@ -0,0 +1,216 @@ +""" +Main implementation of the pipeline. + +This module is internal. Please, import `pipeline` directly instead. +""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from dataclasses import dataclass, field +from typing import Generator + +from .state import PipelineState +from .windowpolicy import Policy, WeeklyPolicy, Window, generate_windows + +from ..spec import aggregator, fieldtesting, filelocking, metrics, oonicollector +from ..aggregators import common, endpointpool, privacy +from .. import lockedfile + + +@dataclass(frozen=True) +class WindowEntries: + """ + A time window along with its data entries. + + Fields: + entries: List of entries for this window. + window: The time window. + """ + + entries: Generator[fieldtesting.Entry] + window: Window + + +@dataclass(frozen=True) +class WindowMeasurements: + """ + A time window along with its measurements. + + Fields: + measurements: List of measurements for this window. + window: The time window. + """ + + measurements: Generator[oonicollector.Measurement] + window: Window + + +@dataclass +class Pipeline: + """ + Pipeline for processing and submitting tunnel metrics. + + Fields: + collector_client: OONI collector client. + provider: Tunnel provider name. + state_file_path: Path to pipeline state file. + upstream_collector: Descriptor for the CSV files collector. + lockedfile_api: (optional) Locked file API. + lockedfile_config: (optional) Locked file configuration. + privacy_config: (optional) Privacy configuration. + window_policy: (optional) Window policy. + """ + + collector_client: oonicollector.Client + provider: str + state_file_path: str + upstream_collector: common.UpstreamCollector + + lockedfile_api: filelocking.API = field(default_factory=lockedfile.API) + lockedfile_config: filelocking.ReadWriteConfig = field( + default_factory=filelocking.ReadWriteConfig + ) + privacy_config: privacy.Config = field(default_factory=privacy.Config) + csv_streamer: fieldtesting.Streamer = field(default_factory=fieldtesting.Streamer) + window_policy: Policy = field(default_factory=WeeklyPolicy) + + def process_csv_file(self, csv_path: str) -> None: + """ + Process a CSV file through the pipeline stages. + + The pipeline: + 1. Splits input CSV into time windows + 2. Transforms window data into measurements using aggregation + 3. Submits measurements to OONI collector + 4. Updates pipeline state + + Args: + csv_path: Path to field testing CSV file + + Raises: + FileLockError: If cannot acquire locks + StateError: If state file operations fail + ValueError: If CSV parsing fails + """ + + # Build state for this specific processing job + state = PipelineState.load( + self.lockedfile_api, + self.state_file_path, + self.lockedfile_config, + ) + + # Hold lock for the whole pipeline run preventing well-behaving + # concurrent processes from interfering with us. + with self.lockedfile_api.mutex(f"{csv_path}.lock"): + # Build the pipeline in stages + stage1 = self._make_window_entries(csv_path, state) + stage2 = self._transform_to_measurements(stage1) + stage3 = self._submit_to_collector(stage2, state) + + # Process each window and ensure we clear the report ID on error + for window in stage3: + # TODO(bassosimone): figure out a way to print progress + pass + + def _make_window_entries(self, csv_path: str, state: PipelineState,) -> Generator[WindowEntries]: + """ + Stage 1: Split CSV data into a window and its entries. + + Args: + csv_path: Input CSV file path + + Returns: + Iterator of window entries + """ + # Get the list of windows we should process + windows = generate_windows( + policy=self.window_policy, + reference=state.next_submission_after, + ) + + # Stream from the CSV file + entries = self.csv_streamer.stream(csv_path) + + # Simplified algorithm: map each window to its entries + # + # Note: it would be possible to rewrite this code by avoiding to + # use a dictionary *as long *as the CSV files are sorted. + mapping: dict[Window, list[fieldtesting.Entry]] = {} + for window in windows: + for entry in entries: + if window.includes_datetime(entry.date): + mapping.setdefault(window, []).append(entry) + + # Yield each window and the related entries + for window, entries in mapping.items(): + yield WindowEntries(entries=(e for e in entries), window=window) + + def _transform_to_measurements( + self, + entries: Generator[WindowEntries], + ) -> Generator[WindowMeasurements]: + """ + Stage 2: Transform WindowEntries into Measurement using aggregation. + + Args: + window_files: Generator of WindowEntries to process. + + Returns: + Generator of WindowMeasurements. + """ + for entry in entries: + # Create the aggregator for this window + aggr: aggregator.Logic = endpointpool.Aggregator( + provider=self.provider, + pool_country=self.upstream_collector.cc, + time_window=metrics.TimeWindow( + start=entry.window.start, + end=entry.window.end, + ), + upstream_collector=self.upstream_collector, + privacy_config=self.privacy_config, + ) + + # Stream measurements back + yield WindowMeasurements( + measurements=aggr.aggregate((e for e in entry.entries)), + window=entry.window, + ) + + def _submit_to_collector( + self, + entries: Generator[WindowMeasurements], + state: PipelineState, + ) -> Generator[Window]: + """ + Stage 3: Submit measurements to OONI collector and update + the pipeline state to know when we stopped processing. + + Args: + measurements: Generator of WindowMeasurements to submit. + """ + + report_id: oonicollector.ReportID | None = None + for entry in entries: + for measurement in entry.measurements: + + # Automatically open a report if needed + if not report_id: + report_id = self.collector_client.create_report( + measurement.as_open_report_request() + ) + + # Attempt to submit the current measurement + self.collector_client.submit_measurement( + report_id, measurement.with_report_id(report_id) + ) + + # Simplified algorithm - update state *after* we have + # processed a whole window, which is suboptimal because + # we could lose data if the process crashes. + state.next_submission_after = entry.window.end + state.save(self.state_file_path, self.lockedfile_config) + + # Let the caller know we're done with this window + yield entry.window diff --git a/aggregatetunnelmetrics/pipeline/processor.py b/aggregatetunnelmetrics/pipeline/processor.py deleted file mode 100644 index c166562..0000000 --- a/aggregatetunnelmetrics/pipeline/processor.py +++ /dev/null @@ -1,151 +0,0 @@ -""" -Main implementation of the metrics processor. - -This module is internal. Please, import `pipeline` directly instead. -""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -import os -import tempfile - -from .config import ProcessConfig, FileIOConfig -from .state import ProcessorState -from .windowpolicy import Policy, Window, generate_windows - -from .. import fieldtestingcsv -from .. import globalscope -from .. import lockedfile -from .. import ooniformatter -from .. import oonireport - - -class MetricsProcessor: - """High-level API for processing and submitting tunnel metrics.""" - - def __init__( - self, - process_config: ProcessConfig, - fileio_config: FileIOConfig, - window_policy: Policy, - ): - self.process_config = process_config - self.fileio_config = fileio_config - self.window_policy = window_policy - - # Initialize configs for sub-components - self.aggregator_config = globalscope.AggregatorConfig( - provider=process_config.provider - ) - - self.collector_client = oonireport.CollectorClient( - oonireport.CollectorConfig( - collector_base_url=process_config.collector_base_url, - timeout=process_config.timeout, - ) - ) - - # Load initial state - self.state = ProcessorState.load( - fileio_config.state_file, - fileio_config.as_lockedfile_fileio_config(), - ) - - # Track current report ID - self._current_report_id: str | None = None - - def _submit_measurement(self, measurement: oonireport.Measurement) -> None: - """Submit a single measurement to OONI.""" - if not self._current_report_id: - self._current_report_id = ( - self.collector_client.create_report_from_measurement(measurement) - ) - - self.collector_client.update_report(self._current_report_id, measurement) - - def process_csv_file(self, csv_path: str) -> None: - """Process CSV file and submit measurements for complete windows. - - Args: - csv_path: Path to field testing CSV file - - Raises: - FileLockError: If cannot acquire locks - StateError: If state file operations fail - ValueError: If CSV parsing fails - SerializationConfigError: If measurement creation fails - """ - # Use mutex to ensure exclusive access - with lockedfile.Mutex(f"{csv_path}.lock"): - # Get a consistent snapshot of CSV - # TODO(bassosimone): consider whether streaming would be possible here - csv_content = lockedfile.read( - csv_path, - self.fileio_config.as_lockedfile_fileio_config(), - ) - - # Process in temp file - with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp: - tmp.write(csv_content) - tmp_path = tmp.name - - try: - # Get time windows to process - windows = generate_windows( - policy=self.window_policy, - reference=self.state.next_submission_after, - ) - - # Process each window - for window in windows: - # TODO(bassosimone): this design has the issue that we parse the CSV file - # multiple times. Should we instead just parse it once or instead see whether - # we could split the file on creation into well-defined buckets? - self._process_window(tmp_path, window) - - # Update state after successful window processing - self.state.next_submission_after = window.end - self.state.save( - self.fileio_config.state_file, - self.fileio_config.as_lockedfile_fileio_config(), - ) - - finally: - os.unlink(tmp_path) - # Reset report ID for next processing - self._current_report_id = None - - def _process_window(self, csv_path: str, window: Window) -> None: - """Process entries within a specific time window.""" - # Parse and filter entries for window - entries = fieldtestingcsv.parse_file(csv_path) - window_entries = [e for e in entries if window.includes_datetime(e.date)] - - if not window_entries: - return - - # Create aggregate state - state = globalscope.AggregateState( - config=self.aggregator_config, - window_start=window.start, - window_end=window.end, - ) - - # Update state with entries - for entry in window_entries: - state.update(entry) - - # Create and submit measurements - serializer = ooniformatter.Serializer( - self.aggregator_config, - ooniformatter.Config( - upstream_collector=self.process_config.upstream_collector, - probe_asn=self.process_config.probe_asn, - probe_cc=self.process_config.probe_cc, - min_sample_size=self.process_config.min_sample_size, - ), - ) - - measurements = serializer.serialize_global(state) - for measurement in measurements: - self._submit_measurement(measurement) diff --git a/aggregatetunnelmetrics/pipeline/state.py b/aggregatetunnelmetrics/pipeline/state.py index 76eeea5..3c13d30 100644 --- a/aggregatetunnelmetrics/pipeline/state.py +++ b/aggregatetunnelmetrics/pipeline/state.py @@ -13,22 +13,28 @@ from dataclasses import dataclass from datetime import datetime, timezone import json -from .errors import StateError +from ..spec import filelocking -from .. import lockedfile +from .errors import StateError @dataclass -class ProcessorState: +class PipelineState: """Persistent state of the metrics processor.""" + lockedfile_api: filelocking.API next_submission_after: datetime | None = None @classmethod - def load(cls, path: str, config: lockedfile.ReadWriteConfig) -> ProcessorState: + def load( + cls, + api: filelocking.API, + path: str, + config: filelocking.ReadWriteConfig, + ) -> PipelineState: """Load state from file with proper locking.""" try: - content = lockedfile.read(path, config) + content = api.readfile(path, config) data = json.loads(content) # Parse next_submission_after if present @@ -36,18 +42,18 @@ class ProcessorState: next_after_dt = datetime.strptime(next_after, "%Y%m%dT%H%M%SZ").replace( tzinfo=timezone.utc ) - return cls(next_submission_after=next_after_dt) + return cls(lockedfile_api=api, next_submission_after=next_after_dt) - return cls() + return cls(lockedfile_api=api) except FileNotFoundError: - return cls() + return cls(lockedfile_api=api) except json.JSONDecodeError as e: raise StateError(f"Corrupt state file: {e}") except ValueError as e: raise StateError(f"Invalid datetime in state: {e}") - def save(self, path: str, config: lockedfile.ReadWriteConfig) -> None: + def save(self, path: str, config: filelocking.ReadWriteConfig) -> None: """Save state to file with proper locking.""" data = { "next_submission_after": ( @@ -56,4 +62,4 @@ class ProcessorState: else None ) } - lockedfile.write(path, json.dumps(data), config) + self.lockedfile_api.writefile(path, json.dumps(data), config) diff --git a/aggregatetunnelmetrics/pipeline/windowpolicy.py b/aggregatetunnelmetrics/pipeline/windowpolicy.py index 7f4897c..945beed 100644 --- a/aggregatetunnelmetrics/pipeline/windowpolicy.py +++ b/aggregatetunnelmetrics/pipeline/windowpolicy.py @@ -11,7 +11,7 @@ from __future__ import annotations from dataclasses import dataclass from datetime import datetime, timedelta, timezone -from typing import Protocol, runtime_checkable +from typing import Generator, Protocol, runtime_checkable @dataclass(frozen=True) @@ -147,7 +147,7 @@ def generate_windows( policy: Policy, reference: datetime | None = None, now: datetime | None = None, -) -> list[Window]: +) -> Generator[Window]: """Generates all the windows between a reference time and now using a policy. Args: @@ -157,7 +157,7 @@ def generate_windows( now: Optional current time (defaults to current UTC time). Returns: - WindowList containing generated windows and the next start time. + Generator returning all the windows. Raises: ValueError: If the policy generates invalid windows (through Window validation). @@ -175,13 +175,9 @@ def generate_windows( _validate_utc(now, "now") # Initialize by creating the initial window - windows: list[Window] = [] window = policy.start_window(reference) # Generate windows until the current window contains now while window.before_datetime(now) and not window.includes_datetime(now): - windows.append(window) + yield window window = window.next_window() - - # Return the generated windows - return windows diff --git a/aggregatetunnelmetrics/spec/aggregator.py b/aggregatetunnelmetrics/spec/aggregator.py index c852ae3..fdb0cca 100644 --- a/aggregatetunnelmetrics/spec/aggregator.py +++ b/aggregatetunnelmetrics/spec/aggregator.py @@ -11,7 +11,7 @@ Classes: # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Iterator, Protocol, runtime_checkable +from typing import Generator, Protocol, runtime_checkable from . import fieldtesting from . import oonicollector @@ -28,5 +28,5 @@ class Logic(Protocol): def aggregate( self, - entries: Iterator[fieldtesting.Entry], - ) -> Iterator[oonicollector.Measurement]: ... + entries: Generator[fieldtesting.Entry], + ) -> Generator[oonicollector.Measurement]: ... diff --git a/tests/aggregatetunnelmetrics/aggregators/test_common.py b/tests/aggregatetunnelmetrics/aggregators/test_common.py index 3861219..4b53939 100644 --- a/tests/aggregatetunnelmetrics/aggregators/test_common.py +++ b/tests/aggregatetunnelmetrics/aggregators/test_common.py @@ -92,6 +92,7 @@ def test_make_distribution_two_values(): assert dist.p75 == pytest.approx(2.25) assert dist.p99 == pytest.approx(2.97) + def test_make_distribution_three_values(): """Test distribution with three values""" # With three values, we should get valid quantiles diff --git a/tests/aggregatetunnelmetrics/pipeline/test_state.py b/tests/aggregatetunnelmetrics/pipeline/test_state.py index 1344bb9..6f5b1ad 100644 --- a/tests/aggregatetunnelmetrics/pipeline/test_state.py +++ b/tests/aggregatetunnelmetrics/pipeline/test_state.py @@ -9,7 +9,7 @@ import tempfile import unittest from unittest.mock import patch -from aggregatetunnelmetrics.pipeline.state import ProcessorState +from aggregatetunnelmetrics.pipeline.state import PipelineState from aggregatetunnelmetrics.pipeline.errors import StateError from aggregatetunnelmetrics.lockedfile import ReadWriteConfig @@ -33,12 +33,12 @@ class TestProcessorState(unittest.TestCase): def test_initial_state(self): """Test initial state creation.""" - state = ProcessorState() + state = PipelineState() self.assertIsNone(state.next_submission_after) def test_load_nonexistent_file(self): """Test loading from a nonexistent file returns default state.""" - state = ProcessorState.load(self.state_file, self.config) + state = PipelineState.load(self.state_file, self.config) self.assertIsNone(state.next_submission_after) def test_load_valid_state(self): @@ -48,7 +48,7 @@ class TestProcessorState(unittest.TestCase): with open(self.state_file, "w") as f: json.dump({"next_submission_after": test_time}, f) - state = ProcessorState.load(self.state_file, self.config) + state = PipelineState.load(self.state_file, self.config) expected_dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) self.assertEqual(state.next_submission_after, expected_dt) @@ -59,7 +59,7 @@ class TestProcessorState(unittest.TestCase): f.write("not valid json{") with self.assertRaises(StateError) as cm: - ProcessorState.load(self.state_file, self.config) + PipelineState.load(self.state_file, self.config) self.assertIn("Corrupt state file", str(cm.exception)) def test_load_invalid_datetime(self): @@ -69,7 +69,7 @@ class TestProcessorState(unittest.TestCase): json.dump({"next_submission_after": "invalid-date"}, f) with self.assertRaises(StateError) as cm: - ProcessorState.load(self.state_file, self.config) + PipelineState.load(self.state_file, self.config) self.assertIn("Invalid datetime in state", str(cm.exception)) def test_load_missing_next_submission_after(self): @@ -78,13 +78,13 @@ class TestProcessorState(unittest.TestCase): with open(self.state_file, "w") as f: json.dump({}, f) - state = ProcessorState.load(self.state_file, self.config) + state = PipelineState.load(self.state_file, self.config) self.assertIsNone(state.next_submission_after) def test_save_state(self): """Test saving state to file.""" dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - state = ProcessorState(next_submission_after=dt) + state = PipelineState(next_submission_after=dt) state.save(self.state_file, self.config) # Verify saved content @@ -94,7 +94,7 @@ class TestProcessorState(unittest.TestCase): def test_save_none_state(self): """Test saving state with None datetime.""" - state = ProcessorState(next_submission_after=None) + state = PipelineState(next_submission_after=None) state.save(self.state_file, self.config) # Verify saved content @@ -104,7 +104,7 @@ class TestProcessorState(unittest.TestCase): def test_save_with_file_error(self): """Test saving state with file write error.""" - state = ProcessorState() + state = PipelineState() # Mock file write to fail with patch("aggregatetunnelmetrics.lockedfile.write") as mock_write: @@ -117,11 +117,11 @@ class TestProcessorState(unittest.TestCase): """Test that saved state can be loaded correctly.""" # Create and save initial state dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - original_state = ProcessorState(next_submission_after=dt) + original_state = PipelineState(next_submission_after=dt) original_state.save(self.state_file, self.config) # Load state back - loaded_state = ProcessorState.load(self.state_file, self.config) + loaded_state = PipelineState.load(self.state_file, self.config) # Verify loaded state matches original self.assertEqual(loaded_state.next_submission_after, dt) -- GitLab From d1a5f693281ec04d5d8f8e113e844ad127519ad0 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 24 Feb 2025 00:55:12 +0100 Subject: [PATCH 61/75] refactor: finish rewriting the pipeline --- aggregatetunnelmetrics/aggregators/common.py | 2 + .../aggregators/endpointpool.py | 11 +- aggregatetunnelmetrics/pipeline/pipeline.py | 6 +- .../pipeline/windowpolicy.py | 14 +- .../pipeline/test_config.py | 91 --- .../pipeline/test_pipeline.py | 273 ++++++++ .../pipeline/test_processor.py | 354 ---------- .../pipeline/test_state.py | 233 +++---- .../pipeline/test_windowpolicy.py | 634 +++++++++--------- 9 files changed, 713 insertions(+), 905 deletions(-) delete mode 100644 tests/aggregatetunnelmetrics/pipeline/test_config.py create mode 100644 tests/aggregatetunnelmetrics/pipeline/test_pipeline.py delete mode 100644 tests/aggregatetunnelmetrics/pipeline/test_processor.py diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py index 2bb0b80..098f06a 100644 --- a/aggregatetunnelmetrics/aggregators/common.py +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -296,5 +296,7 @@ class UpstreamCollector: asn: str cc: str name: str + + # TODO(bassosimone): it's wrong to have them here. software_name: str software_version: str diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py index 97d9c59..18964e7 100644 --- a/aggregatetunnelmetrics/aggregators/endpointpool.py +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -87,13 +87,17 @@ class Aggregator: # Serialize and yield each measurement for pool_name, pool_values in self.state.pools.items(): for proto_name, proto_metrics in pool_values.protocols.items(): - yield self._create_measurement(pool_name, proto_name, proto_metrics) + statements = proto_metrics.statements() + + # Make sure we don't generate empty measurements + if statements: + yield self._create_measurement(pool_name, proto_name, statements) def _create_measurement( self, pool_name: str, proto_name: str, - proto_metrics: ProtocolMetrics, + bodies: list[metrics.Statement], ) -> oonicollector.Measurement: """ Creates a new OONI Measurement from the given metrics. @@ -105,9 +109,6 @@ class Aggregator: The OONI Measurement. """ - # Serialize the bodies to a list of statements - bodies = proto_metrics.statements() - # Apply privacy filters to the bodies bodies = [ privacy.filter_statement(stmt, self.privacy_config) for stmt in bodies diff --git a/aggregatetunnelmetrics/pipeline/pipeline.py b/aggregatetunnelmetrics/pipeline/pipeline.py index 00d3bd3..274d138 100644 --- a/aggregatetunnelmetrics/pipeline/pipeline.py +++ b/aggregatetunnelmetrics/pipeline/pipeline.py @@ -113,7 +113,11 @@ class Pipeline: # TODO(bassosimone): figure out a way to print progress pass - def _make_window_entries(self, csv_path: str, state: PipelineState,) -> Generator[WindowEntries]: + def _make_window_entries( + self, + csv_path: str, + state: PipelineState, + ) -> Generator[WindowEntries]: """ Stage 1: Split CSV data into a window and its entries. diff --git a/aggregatetunnelmetrics/pipeline/windowpolicy.py b/aggregatetunnelmetrics/pipeline/windowpolicy.py index 945beed..96c35ca 100644 --- a/aggregatetunnelmetrics/pipeline/windowpolicy.py +++ b/aggregatetunnelmetrics/pipeline/windowpolicy.py @@ -27,8 +27,8 @@ class Window: return f"Window({self.start.isoformat()} -> {self.end.isoformat()})" def __post_init__(self): - _validate_utc(self.start, "window start") - _validate_utc(self.end, "window end") + validate_utc(self.start, "window start") + validate_utc(self.end, "window end") if self.start >= self.end: raise ValueError("window start must be before end") if self.delta <= timedelta(0): @@ -78,7 +78,7 @@ class DailyPolicy: A Window starting from 00:00 UTC of the current day, ending the following day. """ - _validate_utc(reference, "reference") + validate_utc(reference, "reference") today_at_midnight = reference.replace(hour=0, minute=0, second=0, microsecond=0) delta = timedelta(days=1) tomorrow_at_midnight = today_at_midnight + delta @@ -103,7 +103,7 @@ class WeeklyPolicy: A Window starting from Monday 00:00 UTC of the current week, ending the following Monday. """ - _validate_utc(reference, "reference") + validate_utc(reference, "reference") # Get midnight of the reference day today_at_midnight = reference.replace(hour=0, minute=0, second=0, microsecond=0) @@ -129,7 +129,7 @@ def datetime_utcnow() -> datetime: return datetime.now(timezone.utc) -def _validate_utc(dt: datetime, param_name: str) -> None: +def validate_utc(dt: datetime, param_name: str) -> None: """Validate that a datetime is UTC. Args: @@ -167,12 +167,12 @@ def generate_windows( # Ensure reference is a valid datetime if reference is None: reference = project_start_time - _validate_utc(reference, "reference") + validate_utc(reference, "reference") # Ensure now is a valid datetime if now is None: now = datetime_utcnow() - _validate_utc(now, "now") + validate_utc(now, "now") # Initialize by creating the initial window window = policy.start_window(reference) diff --git a/tests/aggregatetunnelmetrics/pipeline/test_config.py b/tests/aggregatetunnelmetrics/pipeline/test_config.py deleted file mode 100644 index c9af26a..0000000 --- a/tests/aggregatetunnelmetrics/pipeline/test_config.py +++ /dev/null @@ -1,91 +0,0 @@ -"""Tests for the pipeline configuration module.""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -import unittest - -from aggregatetunnelmetrics.pipeline.config import ProcessConfig, FileIOConfig -from aggregatetunnelmetrics.lockedfile import ReadWriteConfig as LockedFileIOConfig - - -class TestProcessConfig(unittest.TestCase): - """Test ProcessConfig functionality.""" - - def test_valid_minimal_config(self): - """Test creating config with just mandatory fields.""" - config = ProcessConfig( - provider="test-provider", - upstream_collector="test-collector", - probe_asn="AS12345", - probe_cc="XX", - collector_base_url="https://api.ooni.io/", - ) - - self.assertEqual(config.provider, "test-provider") - self.assertEqual(config.upstream_collector, "test-collector") - self.assertEqual(config.probe_asn, "AS12345") - self.assertEqual(config.probe_cc, "XX") - self.assertEqual(config.collector_base_url, "https://api.ooni.io/") - self.assertEqual(config.min_sample_size, 1000) # default value - self.assertEqual(config.timeout, 30.0) # default value - - def test_valid_full_config(self): - """Test creating config with all fields specified.""" - config = ProcessConfig( - provider="test-provider", - upstream_collector="test-collector", - probe_asn="AS12345", - probe_cc="XX", - collector_base_url="https://api.ooni.io/", - min_sample_size=500, - timeout=60.0, - ) - - self.assertEqual(config.provider, "test-provider") - self.assertEqual(config.upstream_collector, "test-collector") - self.assertEqual(config.probe_asn, "AS12345") - self.assertEqual(config.probe_cc, "XX") - self.assertEqual(config.collector_base_url, "https://api.ooni.io/") - self.assertEqual(config.min_sample_size, 500) - self.assertEqual(config.timeout, 60.0) - - -class TestFileIOConfig(unittest.TestCase): - """Test FileIOConfig functionality.""" - - def test_valid_minimal_config(self): - """Test creating config with just mandatory fields.""" - config = FileIOConfig(state_file="/path/to/state.json") - - self.assertEqual(config.state_file, "/path/to/state.json") - self.assertEqual(config.num_retries, 10) # default value - self.assertEqual(config.sleep_interval, 0.1) # default value - - def test_valid_full_config(self): - """Test creating config with all fields specified.""" - config = FileIOConfig( - state_file="/path/to/state.json", - num_retries=5, - sleep_interval=0.2, - ) - - self.assertEqual(config.state_file, "/path/to/state.json") - self.assertEqual(config.num_retries, 5) - self.assertEqual(config.sleep_interval, 0.2) - - def test_conversion_to_lockedfile_config(self): - """Test conversion to lockedfile.FileIOConfig.""" - config = FileIOConfig( - state_file="/path/to/state.json", - num_retries=5, - sleep_interval=0.2, - ) - - locked_config = config.as_lockedfile_fileio_config() - self.assertIsInstance(locked_config, LockedFileIOConfig) - self.assertEqual(locked_config.num_retries, 5) - self.assertEqual(locked_config.sleep_interval, 0.2) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/aggregatetunnelmetrics/pipeline/test_pipeline.py b/tests/aggregatetunnelmetrics/pipeline/test_pipeline.py new file mode 100644 index 0000000..6a3a9ae --- /dev/null +++ b/tests/aggregatetunnelmetrics/pipeline/test_pipeline.py @@ -0,0 +1,273 @@ +"""Tests for the metrics processing pipeline.""" + +# SPDX-License-Identifier: GPL-3.0-or-later + +from datetime import datetime, timedelta, timezone +from typing import Generator + +import pytest + +from aggregatetunnelmetrics.pipeline.errors import StateError +from aggregatetunnelmetrics.pipeline.pipeline import Pipeline +from aggregatetunnelmetrics.pipeline.state import PipelineState +from aggregatetunnelmetrics.spec import fieldtesting, filelocking, oonicollector +from aggregatetunnelmetrics.aggregators import common, privacy +from aggregatetunnelmetrics.pipeline.windowpolicy import DailyPolicy, Window + +# Test fixtures + + +@pytest.fixture +def reference_time(): + """Fixed reference time for testing.""" + return datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + + +@pytest.fixture +def test_window(reference_time): + """Create a test window.""" + return Window( + start=reference_time, + end=reference_time + timedelta(days=1), + delta=timedelta(days=1), + ) + + +@pytest.fixture +def mock_collector_client(): + """Create a mock OONI collector client.""" + + class MockClient: + def __init__(self): + self.report_id = "test-report-id" + self.measurements = [] + + def create_report(self, req: oonicollector.OpenReportRequest) -> str: + return self.report_id + + def submit_measurement(self, rid: str, m: oonicollector.Measurement) -> None: + assert rid == self.report_id + self.measurements.append(m) + + return MockClient() + + +@pytest.fixture +def mock_file_api(): + """Create a mock file locking API.""" + + class MockAPI: + def __init__(self): + self.content = "{}" # Initialize with valid empty JSON object + + def readfile(self, path: str, config: filelocking.ReadWriteConfig) -> str: + return self.content + + def writefile( + self, path: str, data: str, config: filelocking.ReadWriteConfig + ) -> None: + self.content = data + + def mutex(self, path: str) -> filelocking.Mutex: + class MockMutex: + def __enter__(self): + return self + + def __exit__(self, *args): + pass + + return MockMutex() + + return MockAPI() + + +@pytest.fixture +def mock_csv_streamer(): + """Create a mock CSV streamer.""" + + class MockStreamer: + def __init__(self): + self.entries = [] + + def stream(self, filepath: str) -> Generator[fieldtesting.Entry]: + for entry in self.entries: + yield entry + + return MockStreamer() + + +@pytest.fixture +def pipeline(mock_collector_client, mock_file_api, mock_csv_streamer, tmp_path): + """Create a pipeline instance with mocked dependencies.""" + return Pipeline( + collector_client=mock_collector_client, + provider="test-provider", + state_file_path=str(tmp_path / "state.json"), + upstream_collector=common.UpstreamCollector( + asn="AS12345", + cc="US", + name="test-collector", + software_name="solitech_collector", + software_version="0.1.0", + ), + lockedfile_api=mock_file_api, + lockedfile_config=filelocking.ReadWriteConfig(), + privacy_config=privacy.Config(), + csv_streamer=mock_csv_streamer, + window_policy=DailyPolicy(), + ) + + +# Helper functions + + +def create_test_entry(date: datetime, is_tunnel: bool = True) -> fieldtesting.Entry: + """Create a test entry.""" + return fieldtesting.Entry( + filename="test.csv", + date=date, + asn="AS12345", + isp="Test ISP", + est_city="Test City", + user="testuser", + region="testregion", + server_fqdn="test.server.com", + server_ip="1.1.1.1", + mobile=False, + tunnel="tunnel" if is_tunnel else "baseline", + throughput_download=100.0, + throughput_upload=50.0, + latency_download=20.0, + latency_upload=25.0, + retransmission_download=0.01, + retransmission_upload=0.02, + ping_packets_loss=0.0, + ping_roundtrip_min=10.0, + ping_roundtrip_avg=15.0, + ping_roundtrip_max=20.0, + err_message="", + protocol="obfs4", + ping_target_address=None, + ndt_target_hostname=None, + ndt_target_address=None, + ndt_target_port=None, + endpoint_pool_name=None, + ) + + +# Tests + + +def test_empty_csv_processing(pipeline, tmp_path): + """Test processing an empty CSV file.""" + csv_path = tmp_path / "test.csv" + csv_path.touch() + + pipeline.process_csv_file(str(csv_path)) + assert not pipeline.collector_client.measurements + + +def test_successful_processing(pipeline, tmp_path, reference_time): + """Test successful processing of CSV with valid entries.""" + csv_path = tmp_path / "test.csv" + csv_path.touch() + + # Add test entry to mock streamer + entry = create_test_entry(reference_time + timedelta(hours=1)) + pipeline.csv_streamer.entries = [entry] + + pipeline.process_csv_file(str(csv_path)) + assert len(pipeline.collector_client.measurements) > 0 + + +def test_skip_out_of_window_entries(pipeline, tmp_path, test_window): + """Test that entries outside the current window are skipped.""" + csv_path = tmp_path / "test.csv" + csv_path.touch() + + # Create entries outside the window + entries = [ + create_test_entry(test_window.start - timedelta(days=1)), + create_test_entry(test_window.end + timedelta(days=1)), + ] + pipeline.csv_streamer.entries = entries + + pipeline.process_csv_file(str(csv_path)) + assert not pipeline.collector_client.measurements + + +@pytest.mark.parametrize("is_tunnel", [False]) +def test_skip_non_tunnel_entries(pipeline, tmp_path, test_window, is_tunnel): + """Test that non-tunnel entries are skipped.""" + csv_path = tmp_path / "test.csv" + csv_path.touch() + + entries = [ + create_test_entry(test_window.start + timedelta(hours=1), is_tunnel=is_tunnel), + create_test_entry(test_window.start + timedelta(hours=2), is_tunnel=is_tunnel), + ] + pipeline.csv_streamer.entries = entries + + pipeline.process_csv_file(str(csv_path)) + assert len(pipeline.collector_client.measurements) == 0 + + +def test_submission_error_handling(pipeline, tmp_path, reference_time): + """Test handling of measurement submission errors.""" + csv_path = tmp_path / "test.csv" + csv_path.touch() + + # Setup error condition + class MockAPIError(Exception): + pass + + def raise_error(*args): + raise MockAPIError("test error") + + pipeline.collector_client.create_report = raise_error + + entry = create_test_entry(reference_time + timedelta(hours=1)) + pipeline.csv_streamer.entries = [entry] + + with pytest.raises(MockAPIError): + pipeline.process_csv_file(str(csv_path)) + + +def test_state_persistence(pipeline, tmp_path, reference_time): + """Test state persistence between runs.""" + csv_path = tmp_path / "test.csv" + csv_path.touch() + + # Create and process initial state + entry = create_test_entry(reference_time + timedelta(hours=1)) + pipeline.csv_streamer.entries = [entry] + pipeline.process_csv_file(str(csv_path)) + + # Create new pipeline instance with same components + new_csv_streamer = type(pipeline.csv_streamer)() + new_pipeline = Pipeline( + collector_client=pipeline.collector_client, + provider=pipeline.provider, + state_file_path=pipeline.state_file_path, + upstream_collector=pipeline.upstream_collector, + lockedfile_api=pipeline.lockedfile_api, + csv_streamer=new_csv_streamer, + ) + + # Process file again + new_pipeline.process_csv_file(str(csv_path)) + + # Verify state was preserved + assert "next_submission_after" in pipeline.lockedfile_api.content + + +def test_load_empty_content(mock_file_api, tmp_path): + """Test handling of empty content from file API.""" + mock_file_api.content = "" # Explicitly set empty content + + with pytest.raises(StateError, match="Corrupt state file"): + PipelineState.load( + mock_file_api, + str(tmp_path / "state.json"), + filelocking.ReadWriteConfig(), + ) diff --git a/tests/aggregatetunnelmetrics/pipeline/test_processor.py b/tests/aggregatetunnelmetrics/pipeline/test_processor.py deleted file mode 100644 index 2933ad8..0000000 --- a/tests/aggregatetunnelmetrics/pipeline/test_processor.py +++ /dev/null @@ -1,354 +0,0 @@ -"""Tests for the metrics processing pipeline.""" - -# SPDX-License-Identifier: GPL-3.0-or-later - -from datetime import datetime, timedelta, timezone -import os -import tempfile -import unittest -from unittest.mock import Mock, patch - -from aggregatetunnelmetrics.pipeline.processor import MetricsProcessor -from aggregatetunnelmetrics.pipeline.config import ProcessConfig, FileIOConfig -from aggregatetunnelmetrics.pipeline.windowpolicy import DailyPolicy, Window -from aggregatetunnelmetrics.fieldtestingcsv import Entry -from aggregatetunnelmetrics.oonireport import APIError - - -class TestMetricsProcessor(unittest.TestCase): - """Test MetricsProcessor functionality.""" - - def setUp(self): - """Set up test fixtures.""" - # Setup the filesystem - self.temp_dir = tempfile.mkdtemp() - self.csv_path = os.path.join(self.temp_dir, "test.csv") - self.state_path = os.path.join(self.temp_dir, "state.json") - - # Time fixture - using a fixed known time - self.reference_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - - # Create basic configs - self.process_config = ProcessConfig( - provider="test-provider", - upstream_collector="test-collector", - probe_asn="AS12345", - probe_cc="XX", - collector_base_url="https://example.org/", - min_sample_size=10, # Small value for testing - ) - - self.fileio_config = FileIOConfig( - state_file=self.state_path, - num_retries=1, - sleep_interval=0.1, - ) - - self.window_policy = DailyPolicy() - - # Create processor instance - self.processor = MetricsProcessor( - self.process_config, - self.fileio_config, - self.window_policy, - ) - - # Create empty CSV file - with open(self.csv_path, "w") as f: - f.write("") - - # Setup collector mock - self.collector_patcher = patch( - "aggregatetunnelmetrics.oonireport.CollectorClient" - ) - self.mock_collector = self.collector_patcher.start() - self.mock_collector_instance = Mock() - self.mock_collector.return_value = self.mock_collector_instance - - # Setup file operations mock - self.read_patcher = patch("aggregatetunnelmetrics.lockedfile.read") - self.mock_read = self.read_patcher.start() - self.mock_read.return_value = "" - - def tearDown(self): - """Clean up temporary files.""" - self.collector_patcher.stop() - self.read_patcher.stop() - try: - os.unlink(self.csv_path) - os.unlink(self.state_path) - os.rmdir(self.temp_dir) - except FileNotFoundError: - pass - - def test_initialization(self): - """Test proper initialization of MetricsProcessor.""" - self.assertEqual(self.processor.process_config, self.process_config) - self.assertEqual(self.processor.fileio_config, self.fileio_config) - self.assertEqual(self.processor.window_policy, self.window_policy) - self.assertEqual( - self.processor.aggregator_config.provider, self.process_config.provider - ) - self.assertIsNone(self.processor._current_report_id) - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") - def test_empty_csv_processing(self, mock_generate_windows, mock_mutex, mock_parse): - """Test processing an empty CSV file.""" - # Set up empty window list - mock_generate_windows.return_value = [] - - # Mock parse_file to return empty list - mock_parse.return_value = [] - - # Mock mutex context manager - mock_mutex.return_value.__enter__.return_value = Mock() - mock_mutex.return_value.__exit__.return_value = None - - self.processor.process_csv_file(self.csv_path) - - # Verify mutex was used with correct path - mock_mutex.assert_called_once_with(f"{self.csv_path}.lock") - mock_parse.assert_not_called() - - def _create_test_entry(self, date: datetime, is_tunnel: bool = True) -> Entry: - """Helper to create a test entry.""" - return Entry( - filename="test.csv", - date=date, - asn="AS12345", - isp="Test ISP", - est_city="Test City", - user="testuser", - region="testregion", - server_fqdn="test.server.com", - server_ip="1.1.1.1", - mobile=False, - tunnel="tunnel" if is_tunnel else "baseline", - protocol="obfs4", - throughput_download=100.0, - throughput_upload=50.0, - latency_download=20.0, - latency_upload=25.0, - retransmission_download=0.01, - retransmission_upload=0.02, - ping_packets_loss=0.0, - ping_roundtrip_min=10.0, - ping_roundtrip_avg=15.0, - ping_roundtrip_max=20.0, - err_message="", - ) - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") - def test_successful_processing(self, mock_generate_windows, mock_mutex, mock_parse): - """Test successful processing of CSV with valid entries.""" - # Set up window - window = Window( - start=self.reference_time, - end=self.reference_time + timedelta(days=1), - delta=timedelta(days=1), - ) - mock_generate_windows.return_value = [window] - - # Set up entries - entries = [self._create_test_entry(self.reference_time + timedelta(hours=1))] - mock_parse.return_value = entries - - # Mock mutex context manager - mock_mutex.return_value.__enter__.return_value = Mock() - mock_mutex.return_value.__exit__.return_value = None - - # Set up successful submission - self.mock_collector_instance.create_report_from_measurement.return_value = ( - "test-report" - ) - - # Process the file - self.processor.process_csv_file(self.csv_path) - - # Verify submission happened - self.mock_collector_instance.create_report_from_measurement.assert_called_once() - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") - def test_skip_out_of_window_entries( - self, mock_generate_windows, mock_mutex, mock_parse - ): - """Test that entries outside the current window are skipped.""" - # Set up window - window = Window( - start=self.reference_time, - end=self.reference_time + timedelta(days=1), - delta=timedelta(days=1), - ) - mock_generate_windows.return_value = [window] - - # Create test entries outside the window - entries = [ - self._create_test_entry(window.start - timedelta(days=1)), - self._create_test_entry(window.end + timedelta(days=1)), - ] - mock_parse.return_value = entries - - # Mock mutex context manager - mock_mutex.return_value.__enter__.return_value = Mock() - mock_mutex.return_value.__exit__.return_value = None - - self.processor.process_csv_file(self.csv_path) - - # Verify no measurements were created - self.assertIsNone(self.processor._current_report_id) - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") - def test_skip_non_tunnel_entries( - self, mock_generate_windows, mock_mutex, mock_parse - ): - """Test that non-tunnel entries are skipped.""" - # Set up window - window = Window( - start=self.reference_time, - end=self.reference_time + timedelta(days=1), - delta=timedelta(days=1), - ) - mock_generate_windows.return_value = [window] - - # Create test entries with non-tunnel measurements - entries = [ - self._create_test_entry(window.start + timedelta(hours=1), is_tunnel=False), - self._create_test_entry(window.start + timedelta(hours=2), is_tunnel=False), - ] - mock_parse.return_value = entries - - # Mock mutex context manager - mock_mutex.return_value.__enter__.return_value = Mock() - mock_mutex.return_value.__exit__.return_value = None - - self.processor.process_csv_file(self.csv_path) - - # Verify no measurements were created - self.assertIsNone(self.processor._current_report_id) - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") - def test_submission_error_handling( - self, mock_generate_windows, mock_mutex, mock_parse - ): - """Test handling of measurement submission errors.""" - # Set up window - window = Window( - start=self.reference_time, - end=self.reference_time + timedelta(days=1), - delta=timedelta(days=1), - ) - mock_generate_windows.return_value = [window] - - # Create test entries - entries = [ - self._create_test_entry(self.reference_time + timedelta(hours=1)), - ] - mock_parse.return_value = entries - - # Mock mutex context manager - mock_mutex.return_value.__enter__.return_value = Mock() - mock_mutex.return_value.__exit__.return_value = None - - # Set up submission failure - self.mock_collector_instance.create_report_from_measurement.side_effect = ( - APIError("test error") - ) - - # Process should raise the error - with self.assertRaises(APIError): - self.processor.process_csv_file(self.csv_path) - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - def test_mutex_error_handling(self, mock_mutex, mock_parse): - """Test handling of mutex acquisition errors.""" - # Mock mutex acquisition failure - mock_mutex.return_value.__enter__.side_effect = Exception("Lock Error") - - with self.assertRaises(Exception): - self.processor.process_csv_file(self.csv_path) - - def test_state_persistence(self): - """Test that state is properly persisted between processing runs.""" - # Create initial state - self.processor.state.next_submission_after = self.reference_time - self.processor.state.save( - self.state_path, self.fileio_config.as_lockedfile_fileio_config() - ) - - # Create new processor instance - new_processor = MetricsProcessor( - self.process_config, - self.fileio_config, - self.window_policy, - ) - - # Verify state was loaded - self.assertEqual(new_processor.state.next_submission_after, self.reference_time) - - @patch("aggregatetunnelmetrics.fieldtestingcsv.parse_file") - @patch("aggregatetunnelmetrics.lockedfile.Mutex") - @patch("aggregatetunnelmetrics.pipeline.windowpolicy.generate_windows") - def test_state_unchanged_on_submission_failure( - self, mock_generate_windows, mock_mutex, mock_parse - ): - """Test that state is not updated when measurement submission fails.""" - # Set up initial state - self.processor.state.next_submission_after = self.reference_time - self.processor.state.save( - self.state_path, self.fileio_config.as_lockedfile_fileio_config() - ) - - # Set up window - window = Window( - start=self.reference_time, - end=self.reference_time + timedelta(days=1), - delta=timedelta(days=1), - ) - mock_generate_windows.return_value = [window] - - # Create test entries - entries = [ - self._create_test_entry(self.reference_time + timedelta(hours=1)), - ] - mock_parse.return_value = entries - - # Mock mutex context manager - mock_mutex.return_value.__enter__.return_value = Mock() - mock_mutex.return_value.__exit__.return_value = None - - # Set up submission failure - self.mock_collector_instance.create_report_from_measurement.side_effect = ( - APIError("test error") - ) - - # Process should raise the error - with self.assertRaises(APIError): - self.processor.process_csv_file(self.csv_path) - - # Verify state hasn't changed - self.assertEqual( - self.processor.state.next_submission_after, self.reference_time - ) - - # Load state from disk and verify it hasn't changed - new_processor = MetricsProcessor( - self.process_config, - self.fileio_config, - self.window_policy, - ) - self.assertEqual(new_processor.state.next_submission_after, self.reference_time) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/aggregatetunnelmetrics/pipeline/test_state.py b/tests/aggregatetunnelmetrics/pipeline/test_state.py index 6f5b1ad..3c72a52 100644 --- a/tests/aggregatetunnelmetrics/pipeline/test_state.py +++ b/tests/aggregatetunnelmetrics/pipeline/test_state.py @@ -3,129 +3,118 @@ # SPDX-License-Identifier: GPL-3.0-or-later from datetime import datetime, timezone + +import pytest import json -import os -import tempfile -import unittest -from unittest.mock import patch from aggregatetunnelmetrics.pipeline.state import PipelineState from aggregatetunnelmetrics.pipeline.errors import StateError -from aggregatetunnelmetrics.lockedfile import ReadWriteConfig - - -class TestProcessorState(unittest.TestCase): - """Test ProcessorState functionality.""" - - def setUp(self): - """Create a temporary file for testing.""" - self.temp_dir = tempfile.mkdtemp() - self.state_file = os.path.join(self.temp_dir, "state.json") - self.config = ReadWriteConfig(num_retries=1, sleep_interval=0.1) - - def tearDown(self): - """Clean up temporary files.""" - try: - os.unlink(self.state_file) - os.rmdir(self.temp_dir) - except FileNotFoundError: - pass - - def test_initial_state(self): - """Test initial state creation.""" - state = PipelineState() - self.assertIsNone(state.next_submission_after) - - def test_load_nonexistent_file(self): - """Test loading from a nonexistent file returns default state.""" - state = PipelineState.load(self.state_file, self.config) - self.assertIsNone(state.next_submission_after) - - def test_load_valid_state(self): - """Test loading valid state from file.""" - # Write valid state file - test_time = "20240101T120000Z" - with open(self.state_file, "w") as f: - json.dump({"next_submission_after": test_time}, f) - - state = PipelineState.load(self.state_file, self.config) - expected_dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - self.assertEqual(state.next_submission_after, expected_dt) - - def test_load_corrupt_json(self): - """Test loading corrupted JSON state file.""" - # Write invalid JSON - with open(self.state_file, "w") as f: - f.write("not valid json{") - - with self.assertRaises(StateError) as cm: - PipelineState.load(self.state_file, self.config) - self.assertIn("Corrupt state file", str(cm.exception)) - - def test_load_invalid_datetime(self): - """Test loading state with invalid datetime format.""" - # Write state with invalid datetime - with open(self.state_file, "w") as f: - json.dump({"next_submission_after": "invalid-date"}, f) - - with self.assertRaises(StateError) as cm: - PipelineState.load(self.state_file, self.config) - self.assertIn("Invalid datetime in state", str(cm.exception)) - - def test_load_missing_next_submission_after(self): - """Test loading state file with missing next_submission_after field returns default state.""" - # Write state file without next_submission_after field - with open(self.state_file, "w") as f: - json.dump({}, f) - - state = PipelineState.load(self.state_file, self.config) - self.assertIsNone(state.next_submission_after) - - def test_save_state(self): - """Test saving state to file.""" - dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - state = PipelineState(next_submission_after=dt) - state.save(self.state_file, self.config) - - # Verify saved content - with open(self.state_file) as f: - saved_state = json.load(f) - self.assertEqual(saved_state["next_submission_after"], "20240101T120000Z") - - def test_save_none_state(self): - """Test saving state with None datetime.""" - state = PipelineState(next_submission_after=None) - state.save(self.state_file, self.config) - - # Verify saved content - with open(self.state_file) as f: - saved_state = json.load(f) - self.assertIsNone(saved_state["next_submission_after"]) - - def test_save_with_file_error(self): - """Test saving state with file write error.""" - state = PipelineState() - - # Mock file write to fail - with patch("aggregatetunnelmetrics.lockedfile.write") as mock_write: - mock_write.side_effect = IOError("Mock write error") - - with self.assertRaises(IOError): - state.save(self.state_file, self.config) - - def test_load_save_roundtrip(self): - """Test that saved state can be loaded correctly.""" - # Create and save initial state - dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) - original_state = PipelineState(next_submission_after=dt) - original_state.save(self.state_file, self.config) - - # Load state back - loaded_state = PipelineState.load(self.state_file, self.config) - - # Verify loaded state matches original - self.assertEqual(loaded_state.next_submission_after, dt) - - -if __name__ == "__main__": - unittest.main() +from aggregatetunnelmetrics.lockedfile import ReadWriteConfig, API + + +@pytest.fixture +def locked_file_api(): + """Create a mock locked file API.""" + return API() + + +@pytest.fixture +def config(): + """Create a test configuration.""" + return ReadWriteConfig(num_retries=1, sleep_interval=0.1) + + +@pytest.fixture +def state_file(tmp_path): + """Create a temporary state file path.""" + return tmp_path / "state.json" + + +def test_initial_state(locked_file_api): + """Test initial state creation.""" + state = PipelineState(lockedfile_api=locked_file_api) + assert state.next_submission_after is None + + +def test_load_nonexistent_file(locked_file_api, state_file, config): + """Test loading from a nonexistent file returns default state.""" + state = PipelineState.load(locked_file_api, state_file, config) + assert state.next_submission_after is None + + +def test_load_valid_state(locked_file_api, state_file, config): + """Test loading valid state from file.""" + test_time = "20240101T120000Z" + state_file.write_text(json.dumps({"next_submission_after": test_time})) + + state = PipelineState.load(locked_file_api, state_file, config) + expected_dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + assert state.next_submission_after == expected_dt + + +def test_load_corrupt_json(locked_file_api, state_file, config): + """Test loading corrupted JSON state file.""" + state_file.write_text("not valid json{") + + with pytest.raises(StateError, match="Corrupt state file"): + PipelineState.load(locked_file_api, state_file, config) + + +def test_load_invalid_datetime(locked_file_api, state_file, config): + """Test loading state with invalid datetime format.""" + state_file.write_text(json.dumps({"next_submission_after": "invalid-date"})) + + with pytest.raises(StateError, match="Invalid datetime in state"): + PipelineState.load(locked_file_api, state_file, config) + + +def test_load_missing_next_submission_after(locked_file_api, state_file, config): + """Test loading state file with missing next_submission_after field returns default state.""" + state_file.write_text(json.dumps({})) + + state = PipelineState.load(locked_file_api, state_file, config) + assert state.next_submission_after is None + + +def test_save_state(locked_file_api, state_file, config): + """Test saving state to file.""" + dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + state = PipelineState(lockedfile_api=locked_file_api, next_submission_after=dt) + state.save(state_file, config) + + saved_state = json.loads(state_file.read_text()) + assert saved_state["next_submission_after"] == "20240101T120000Z" + + +def test_save_none_state(locked_file_api, state_file, config): + """Test saving state with None datetime.""" + state = PipelineState(lockedfile_api=locked_file_api, next_submission_after=None) + state.save(state_file, config) + + saved_state = json.loads(state_file.read_text()) + assert saved_state["next_submission_after"] is None + + +def test_save_with_file_error(locked_file_api, state_file, config, monkeypatch): + """Test saving state with file write error.""" + state = PipelineState(lockedfile_api=locked_file_api) + + def mock_write(*args): + raise IOError("Mock write error") + + monkeypatch.setattr(locked_file_api, "writefile", mock_write) + + with pytest.raises(IOError): + state.save(state_file, config) + + +def test_load_save_roundtrip(locked_file_api, state_file, config): + """Test that saved state can be loaded correctly.""" + dt = datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc) + original_state = PipelineState( + lockedfile_api=locked_file_api, next_submission_after=dt + ) + original_state.save(state_file, config) + + loaded_state = PipelineState.load(locked_file_api, state_file, config) + assert loaded_state.next_submission_after == dt diff --git a/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py b/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py index f5ed3aa..c8b3434 100644 --- a/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py +++ b/tests/aggregatetunnelmetrics/pipeline/test_windowpolicy.py @@ -2,333 +2,317 @@ # SPDX-License-Identifier: GPL-3.0-or-later -import unittest +import pytest from datetime import datetime, timedelta, timezone -from aggregatetunnelmetrics.pipeline.windowpolicy import ( - DailyPolicy, - WeeklyPolicy, - Window, - _validate_utc, - datetime_utcnow, - generate_windows, +from aggregatetunnelmetrics.pipeline import windowpolicy + +# Window Tests + + +def test_valid_window_creation(): + """Test creating a valid window.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=1) + window = windowpolicy.Window(start, end, delta) + assert window.start == start + assert window.end == end + assert window.delta == delta + + +def test_window_validation_non_utc(): + """Test window creation with non-UTC times.""" + start = datetime(2024, 1, 1) # naive datetime + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=1) + with pytest.raises(ValueError): + windowpolicy.Window(start, end, delta) + + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2) # naive datetime + with pytest.raises(ValueError): + windowpolicy.Window(start, end, delta) + + +def test_window_validation_invalid_times(): + """Test window creation with invalid time combinations.""" + start = datetime(2024, 1, 2, tzinfo=timezone.utc) + end = datetime(2024, 1, 1, tzinfo=timezone.utc) + delta = timedelta(days=1) + with pytest.raises(ValueError): + windowpolicy.Window(start, end, delta) + + +@pytest.mark.parametrize( + "delta", + [ + timedelta(days=-1), + timedelta(0), + ], ) +def test_window_validation_invalid_delta(delta): + """Test window creation with invalid delta.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + with pytest.raises(ValueError): + windowpolicy.Window(start, end, delta) + + +def test_window_validation_mismatched_delta(): + """Test window creation with delta not matching start-end difference.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + delta = timedelta(days=2) + with pytest.raises(ValueError): + windowpolicy.Window(start, end, delta) + + +@pytest.mark.parametrize( + "test_time,expected", + [ + (datetime(2024, 1, 1, tzinfo=timezone.utc), True), + (datetime(2024, 1, 2, tzinfo=timezone.utc), True), + (datetime(2024, 1, 3, tzinfo=timezone.utc), True), + (datetime(2023, 12, 31, tzinfo=timezone.utc), False), + ], +) +def test_window_before_datetime(test_time, expected): + """Test the before_datetime method.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = windowpolicy.Window(start, end, timedelta(days=1)) + assert window.before_datetime(test_time) == expected + + +@pytest.mark.parametrize( + "test_time,expected", + [ + (datetime(2024, 1, 1, tzinfo=timezone.utc), True), + (datetime(2024, 1, 1, 12, tzinfo=timezone.utc), True), + (datetime(2024, 1, 2, tzinfo=timezone.utc), False), + (datetime(2023, 12, 31, tzinfo=timezone.utc), False), + ], +) +def test_window_includes_datetime(test_time, expected): + """Test the includes_datetime method.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = windowpolicy.Window(start, end, timedelta(days=1)) + assert window.includes_datetime(test_time) == expected + + +def test_window_next_window(): + """Test the next_window method.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = windowpolicy.Window(start, end, timedelta(days=1)) + + next_window = window.next_window() + assert next_window.start == window.end + assert next_window.end == datetime(2024, 1, 3, tzinfo=timezone.utc) + assert next_window.delta == window.delta + + +def test_window_string_representation(): + """Test the string representation of Window.""" + start = datetime(2024, 1, 1, tzinfo=timezone.utc) + end = datetime(2024, 1, 2, tzinfo=timezone.utc) + window = windowpolicy.Window(start, end, timedelta(days=1)) + expected = f"Window({start.isoformat()} -> {end.isoformat()})" + assert str(window) == expected + + +# Daily Policy Tests + + +@pytest.mark.parametrize( + "reference,expected_start", + [ + ( + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 1, 12, 30, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 1, 23, 59, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ], +) +def test_daily_policy_start_window_different_times(reference, expected_start): + """Test DailyPolicy.start_window with reference at different times of day.""" + policy = windowpolicy.DailyPolicy() + window = policy.start_window(reference) + assert window.start == expected_start + assert window.end == expected_start + timedelta(days=1) + assert window.delta == timedelta(days=1) + + +@pytest.mark.parametrize( + "reference,expected_start", + [ + ( + datetime(2024, 1, 1, 12, tzinfo=timezone.utc), + datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 2, 12, tzinfo=timezone.utc), + datetime(2024, 1, 2, 0, 0, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 3, 12, tzinfo=timezone.utc), + datetime(2024, 1, 3, 0, 0, tzinfo=timezone.utc), + ), + ], +) +def test_daily_policy_across_days(reference, expected_start): + """Test DailyPolicy.start_window across different days.""" + policy = windowpolicy.DailyPolicy() + window = policy.start_window(reference) + assert window.start == expected_start + assert window.end == expected_start + timedelta(days=1) + assert window.delta == timedelta(days=1) + + +def test_daily_policy_non_utc(): + """Test DailyPolicy with non-UTC reference.""" + policy = windowpolicy.DailyPolicy() + reference = datetime(2024, 1, 1, 15, 30) # naive datetime + with pytest.raises(ValueError): + policy.start_window(reference) + + +# Weekly Policy Tests + + +@pytest.mark.parametrize( + "reference,expected_start", + [ + # Monday through Sunday + ( + datetime(2024, 1, 1, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 2, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 3, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 4, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 5, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 6, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ( + datetime(2024, 1, 7, tzinfo=timezone.utc), + datetime(2024, 1, 1, tzinfo=timezone.utc), + ), + ], +) +def test_weekly_policy_start_window(reference, expected_start): + """Test WeeklyPolicy.start_window for different days of the week.""" + policy = windowpolicy.WeeklyPolicy() + window = policy.start_window(reference) + assert window.start == expected_start + assert window.end == expected_start + timedelta(days=7) + assert window.delta == timedelta(days=7) + + +def test_weekly_policy_non_utc(): + """Test WeeklyPolicy with non-UTC reference.""" + policy = windowpolicy.WeeklyPolicy() + reference = datetime(2024, 1, 1, 15, 30) # naive datetime + with pytest.raises(ValueError): + policy.start_window(reference) + + +# Generate Windows Tests + + +def test_generate_windows_default_params(): + """Test generate_windows with default parameters.""" + policy = windowpolicy.DailyPolicy() + windows = list(windowpolicy.generate_windows(policy)) + assert isinstance(windows, list) + assert len(windows) > 0 + + +def test_generate_windows_custom_reference(): + """Test generate_windows with custom reference time.""" + policy = windowpolicy.DailyPolicy() + reference = datetime(2024, 1, 1, tzinfo=timezone.utc) + now = datetime(2024, 1, 3, tzinfo=timezone.utc) + windows = list(windowpolicy.generate_windows(policy, reference=reference, now=now)) + + assert len(windows) == 2 # Should have Jan 1-2 and Jan 2-3 + assert windows[0].start == reference + + +def test_generate_windows_non_utc(): + """Test generate_windows with non-UTC times.""" + policy = windowpolicy.DailyPolicy() + + # Test with naive reference time + reference = datetime(2024, 1, 1) + now = datetime(2024, 1, 2, tzinfo=timezone.utc) + with pytest.raises(ValueError): + list(windowpolicy.generate_windows(policy, reference=reference, now=now)) + + # Test with naive now time + reference = datetime(2024, 1, 1, tzinfo=timezone.utc) + now = datetime(2024, 1, 2) + with pytest.raises(ValueError): + list(windowpolicy.generate_windows(policy, reference=reference, now=now)) + + +def test_generate_windows_reference_after_now(): + """Test generate_windows with reference time after now.""" + policy = windowpolicy.DailyPolicy() + reference = datetime(2024, 1, 2, tzinfo=timezone.utc) + now = datetime(2024, 1, 1, tzinfo=timezone.utc) + windows = list(windowpolicy.generate_windows(policy, reference=reference, now=now)) + assert len(windows) == 0 + + +def test_generate_windows_same_day(): + """Test generate_windows with reference and now on the same day.""" + policy = windowpolicy.DailyPolicy() + reference = datetime(2024, 1, 1, tzinfo=timezone.utc) + now = datetime(2024, 1, 1, 12, tzinfo=timezone.utc) + windows = list(windowpolicy.generate_windows(policy, reference=reference, now=now)) + assert len(windows) == 0 + + +# Utility Tests + + +def test_datetime_utcnow(): + """Test datetime_utcnow function.""" + now = windowpolicy.datetime_utcnow() + assert now.tzinfo == timezone.utc + +def test_validate_utc(): + """Test _validate_utc function.""" + # Valid UTC datetime + dt = datetime(2024, 1, 1, tzinfo=timezone.utc) + windowpolicy.validate_utc(dt, "test") # Should not raise -class TestWindow(unittest.TestCase): - """Test the Window class.""" - - def test_valid_window_creation(self): - """Test creating a valid window.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - delta = timedelta(days=1) - window = Window(start, end, delta) - self.assertEqual(window.start, start) - self.assertEqual(window.end, end) - self.assertEqual(window.delta, delta) - - def test_window_validation_non_utc(self): - """Test window creation with non-UTC times.""" - start = datetime(2024, 1, 1) # naive datetime - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - delta = timedelta(days=1) - with self.assertRaises(ValueError): - Window(start, end, delta) - - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2) # naive datetime - with self.assertRaises(ValueError): - Window(start, end, delta) - - def test_window_validation_invalid_times(self): - """Test window creation with invalid time combinations.""" - start = datetime(2024, 1, 2, tzinfo=timezone.utc) - end = datetime(2024, 1, 1, tzinfo=timezone.utc) - delta = timedelta(days=1) - with self.assertRaises(ValueError): - Window(start, end, delta) - - def test_window_validation_invalid_delta(self): - """Test window creation with invalid delta.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - delta = timedelta(days=-1) - with self.assertRaises(ValueError): - Window(start, end, delta) - - delta = timedelta(0) - with self.assertRaises(ValueError): - Window(start, end, delta) - - def test_window_validation_mismatched_delta(self): - """Test window creation with delta not matching start-end difference.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - delta = timedelta(days=2) - with self.assertRaises(ValueError): - Window(start, end, delta) - - def test_window_before_datetime(self): - """Test the before_datetime method.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - window = Window(start, end, timedelta(days=1)) - - # Test various cases - self.assertTrue( - window.before_datetime(datetime(2024, 1, 1, tzinfo=timezone.utc)) - ) - self.assertTrue( - window.before_datetime(datetime(2024, 1, 2, tzinfo=timezone.utc)) - ) - self.assertTrue( - window.before_datetime(datetime(2024, 1, 3, tzinfo=timezone.utc)) - ) - self.assertFalse( - window.before_datetime(datetime(2023, 12, 31, tzinfo=timezone.utc)) - ) - - def test_window_includes_datetime(self): - """Test the includes_datetime method.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - window = Window(start, end, timedelta(days=1)) - - # Test various cases - self.assertTrue( - window.includes_datetime(datetime(2024, 1, 1, tzinfo=timezone.utc)) - ) - self.assertTrue( - window.includes_datetime(datetime(2024, 1, 1, 12, tzinfo=timezone.utc)) - ) - self.assertFalse( - window.includes_datetime(datetime(2024, 1, 2, tzinfo=timezone.utc)) - ) - self.assertFalse( - window.includes_datetime(datetime(2023, 12, 31, tzinfo=timezone.utc)) - ) - - def test_window_next_window(self): - """Test the next_window method.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - window = Window(start, end, timedelta(days=1)) - - next_window = window.next_window() - self.assertEqual(next_window.start, window.end) - self.assertEqual(next_window.end, datetime(2024, 1, 3, tzinfo=timezone.utc)) - self.assertEqual(next_window.delta, window.delta) - - def test_window_string_representation(self): - """Test the string representation of Window.""" - start = datetime(2024, 1, 1, tzinfo=timezone.utc) - end = datetime(2024, 1, 2, tzinfo=timezone.utc) - window = Window(start, end, timedelta(days=1)) - expected = f"Window({start.isoformat()} -> {end.isoformat()})" - self.assertEqual(str(window), expected) - - -class TestDailyPolicy(unittest.TestCase): - """Test the DailyPolicy class.""" - - def test_daily_policy_start_window_different_times(self): - """Test DailyPolicy.start_window with reference at different times of day.""" - policy = DailyPolicy() - - # Test different times during January 1, 2024 - test_cases = [ - # (reference_time, expected_start) - ( - datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), # midnight - datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 1, 12, 30, tzinfo=timezone.utc), # noon - datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 1, 23, 59, tzinfo=timezone.utc), # end of day - datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), - ] - - for reference, expected_start in test_cases: - with self.subTest(reference=reference): - window = policy.start_window(reference) - self.assertEqual(window.start, expected_start) - self.assertEqual(window.end, expected_start + timedelta(days=1)) - self.assertEqual(window.delta, timedelta(days=1)) - - def test_daily_policy_across_days(self): - """Test DailyPolicy.start_window across different days.""" - policy = DailyPolicy() - - test_cases = [ - # (reference_time, expected_start) - ( - datetime(2024, 1, 1, 12, tzinfo=timezone.utc), - datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 2, 12, tzinfo=timezone.utc), - datetime(2024, 1, 2, 0, 0, tzinfo=timezone.utc), - ), - ( - datetime(2024, 1, 3, 12, tzinfo=timezone.utc), - datetime(2024, 1, 3, 0, 0, tzinfo=timezone.utc), - ), - ] - - for reference, expected_start in test_cases: - with self.subTest(reference=reference): - window = policy.start_window(reference) - self.assertEqual(window.start, expected_start) - self.assertEqual(window.end, expected_start + timedelta(days=1)) - self.assertEqual(window.delta, timedelta(days=1)) - - def test_daily_policy_non_utc(self): - """Test DailyPolicy with non-UTC reference.""" - policy = DailyPolicy() - reference = datetime(2024, 1, 1, 15, 30) # naive datetime - with self.assertRaises(ValueError): - policy.start_window(reference) - - -class TestWeeklyPolicy(unittest.TestCase): - """Test the WeeklyPolicy class.""" - - def test_weekly_policy_start_window(self): - """Test WeeklyPolicy.start_window for different days of the week.""" - policy = WeeklyPolicy() - - # Test for each day of the week - test_cases = [ - # Monday - ( - datetime(2024, 1, 1, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - # Tuesday - ( - datetime(2024, 1, 2, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - # Wednesday - ( - datetime(2024, 1, 3, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - # Thursday - ( - datetime(2024, 1, 4, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - # Friday - ( - datetime(2024, 1, 5, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - # Saturday - ( - datetime(2024, 1, 6, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - # Sunday - ( - datetime(2024, 1, 7, tzinfo=timezone.utc), - datetime(2024, 1, 1, tzinfo=timezone.utc), - ), - ] - - for reference, expected_start in test_cases: - with self.subTest(reference=reference): - window = policy.start_window(reference) - self.assertEqual(window.start, expected_start) - self.assertEqual(window.end, expected_start + timedelta(days=7)) - self.assertEqual(window.delta, timedelta(days=7)) - - def test_weekly_policy_non_utc(self): - """Test WeeklyPolicy with non-UTC reference.""" - policy = WeeklyPolicy() - reference = datetime(2024, 1, 1, 15, 30) # naive datetime - with self.assertRaises(ValueError): - policy.start_window(reference) - - -class TestGenerateWindows(unittest.TestCase): - """Test the generate_windows function.""" - - def test_generate_windows_default_params(self): - """Test generate_windows with default parameters.""" - policy = DailyPolicy() - windows = generate_windows(policy) - self.assertIsInstance(windows, list) - self.assertGreater(len(windows), 0) - - def test_generate_windows_custom_reference(self): - """Test generate_windows with custom reference time.""" - policy = DailyPolicy() - reference = datetime(2024, 1, 1, tzinfo=timezone.utc) - now = datetime(2024, 1, 3, tzinfo=timezone.utc) - windows = generate_windows(policy, reference=reference, now=now) - - self.assertEqual(len(windows), 2) # Should have Jan 1-2 and Jan 2-3 - self.assertEqual(windows[0].start, reference) - - def test_generate_windows_non_utc(self): - """Test generate_windows with non-UTC times.""" - policy = DailyPolicy() - reference = datetime(2024, 1, 1) # naive datetime - now = datetime(2024, 1, 2, tzinfo=timezone.utc) - - with self.assertRaises(ValueError): - generate_windows(policy, reference=reference, now=now) - - reference = datetime(2024, 1, 1, tzinfo=timezone.utc) - now = datetime(2024, 1, 2) # naive datetime - - with self.assertRaises(ValueError): - generate_windows(policy, reference=reference, now=now) - - def test_generate_windows_reference_after_now(self): - """Test generate_windows with reference time after now.""" - policy = DailyPolicy() - reference = datetime(2024, 1, 2, tzinfo=timezone.utc) - now = datetime(2024, 1, 1, tzinfo=timezone.utc) - - windows = generate_windows(policy, reference=reference, now=now) - self.assertEqual(len(windows), 0) - - def test_generate_windows_same_day(self): - """Test generate_windows with reference and now on the same day.""" - policy = DailyPolicy() - reference = datetime(2024, 1, 1, tzinfo=timezone.utc) - now = datetime(2024, 1, 1, 12, tzinfo=timezone.utc) - - windows = generate_windows(policy, reference=reference, now=now) - self.assertEqual(len(windows), 0) - - -class TestUtilities(unittest.TestCase): - """Test utility functions.""" - - def test_datetime_utcnow(self): - """Test datetime_utcnow function.""" - now = datetime_utcnow() - self.assertEqual(now.tzinfo, timezone.utc) - - def test_validate_utc(self): - """Test _validate_utc function.""" - # Valid UTC datetime - dt = datetime(2024, 1, 1, tzinfo=timezone.utc) - _validate_utc(dt, "test") # Should not raise - - # Naive datetime - dt = datetime(2024, 1, 1) - with self.assertRaises(ValueError): - _validate_utc(dt, "test") - - -if __name__ == "__main__": - unittest.main() + # Naive datetime + dt = datetime(2024, 1, 1) + with pytest.raises(ValueError): + windowpolicy.validate_utc(dt, "test") -- GitLab From 2259914cfe85a31b24cb78759eec99b7ae7d512f Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Mon, 24 Feb 2025 01:03:38 +0100 Subject: [PATCH 62/75] move what we don't need to the attic --- Attic/README.md | 1 + {aggregatetunnelmetrics => Attic}/globalscope/__init__.py | 0 {aggregatetunnelmetrics => Attic}/globalscope/aggregate.py | 0 {ooniformat => Attic/ooniformat}/__init__.py | 0 {ooniformat => Attic/ooniformat}/serializer.py | 0 {ooniformat => Attic/ooniformat}/testkeys.py | 0 {aggregatetunnelmetrics => Attic}/ooniformatter/__init__.py | 0 {aggregatetunnelmetrics => Attic}/ooniformatter/formatter.py | 0 {oonisubmitter => Attic/oonisubmitter}/DESIGN.md | 0 {oonisubmitter => Attic/oonisubmitter}/__init__.py | 0 {oonisubmitter => Attic/oonisubmitter}/aggregator.py | 0 {oonisubmitter => Attic/oonisubmitter}/identifiers.py | 0 {oonisubmitter => Attic/oonisubmitter}/lockedfile.py | 0 {oonisubmitter => Attic/oonisubmitter}/model.py | 0 {oonisubmitter => Attic/oonisubmitter}/ooniapi.py | 0 {oonisubmitter => Attic/oonisubmitter}/policy.py | 0 {oonisubmitter => Attic/oonisubmitter}/serializer.py | 0 {oonisubmitter => Attic/oonisubmitter}/submitter.py | 0 .../aggregatetunnelmetrics => Attic/t}/globalscope/__init__.py | 0 .../t}/globalscope/test_aggregate.py | 0 .../aggregatetunnelmetrics => Attic/t}/ooniformatter/__init__.py | 0 .../t}/ooniformatter/test_formatter.py | 0 {tests => Attic/t}/oonisubmitter/__init__.py | 0 {tests => Attic/t}/oonisubmitter/test_aggregator.py | 0 {tests => Attic/t}/oonisubmitter/test_identifiers.py | 0 {tests => Attic/t}/oonisubmitter/testdata/expected_state.json | 0 {tests => Attic/t}/oonisubmitter/testdata/sample.csv | 0 {tunnelmetrics => Attic/tunnelmetrics}/__init__.py | 0 {tunnelmetrics => Attic/tunnelmetrics}/endpoint.py | 0 {tunnelmetrics => Attic/tunnelmetrics}/identifiers.py | 0 {tunnelmetrics => Attic/tunnelmetrics}/model.py | 0 31 files changed, 1 insertion(+) create mode 100644 Attic/README.md rename {aggregatetunnelmetrics => Attic}/globalscope/__init__.py (100%) rename {aggregatetunnelmetrics => Attic}/globalscope/aggregate.py (100%) rename {ooniformat => Attic/ooniformat}/__init__.py (100%) rename {ooniformat => Attic/ooniformat}/serializer.py (100%) rename {ooniformat => Attic/ooniformat}/testkeys.py (100%) rename {aggregatetunnelmetrics => Attic}/ooniformatter/__init__.py (100%) rename {aggregatetunnelmetrics => Attic}/ooniformatter/formatter.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/DESIGN.md (100%) rename {oonisubmitter => Attic/oonisubmitter}/__init__.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/aggregator.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/identifiers.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/lockedfile.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/model.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/ooniapi.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/policy.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/serializer.py (100%) rename {oonisubmitter => Attic/oonisubmitter}/submitter.py (100%) rename {tests/aggregatetunnelmetrics => Attic/t}/globalscope/__init__.py (100%) rename {tests/aggregatetunnelmetrics => Attic/t}/globalscope/test_aggregate.py (100%) rename {tests/aggregatetunnelmetrics => Attic/t}/ooniformatter/__init__.py (100%) rename {tests/aggregatetunnelmetrics => Attic/t}/ooniformatter/test_formatter.py (100%) rename {tests => Attic/t}/oonisubmitter/__init__.py (100%) rename {tests => Attic/t}/oonisubmitter/test_aggregator.py (100%) rename {tests => Attic/t}/oonisubmitter/test_identifiers.py (100%) rename {tests => Attic/t}/oonisubmitter/testdata/expected_state.json (100%) rename {tests => Attic/t}/oonisubmitter/testdata/sample.csv (100%) rename {tunnelmetrics => Attic/tunnelmetrics}/__init__.py (100%) rename {tunnelmetrics => Attic/tunnelmetrics}/endpoint.py (100%) rename {tunnelmetrics => Attic/tunnelmetrics}/identifiers.py (100%) rename {tunnelmetrics => Attic/tunnelmetrics}/model.py (100%) diff --git a/Attic/README.md b/Attic/README.md new file mode 100644 index 0000000..8b149e4 --- /dev/null +++ b/Attic/README.md @@ -0,0 +1 @@ +Stuff that we should most likely not merge. diff --git a/aggregatetunnelmetrics/globalscope/__init__.py b/Attic/globalscope/__init__.py similarity index 100% rename from aggregatetunnelmetrics/globalscope/__init__.py rename to Attic/globalscope/__init__.py diff --git a/aggregatetunnelmetrics/globalscope/aggregate.py b/Attic/globalscope/aggregate.py similarity index 100% rename from aggregatetunnelmetrics/globalscope/aggregate.py rename to Attic/globalscope/aggregate.py diff --git a/ooniformat/__init__.py b/Attic/ooniformat/__init__.py similarity index 100% rename from ooniformat/__init__.py rename to Attic/ooniformat/__init__.py diff --git a/ooniformat/serializer.py b/Attic/ooniformat/serializer.py similarity index 100% rename from ooniformat/serializer.py rename to Attic/ooniformat/serializer.py diff --git a/ooniformat/testkeys.py b/Attic/ooniformat/testkeys.py similarity index 100% rename from ooniformat/testkeys.py rename to Attic/ooniformat/testkeys.py diff --git a/aggregatetunnelmetrics/ooniformatter/__init__.py b/Attic/ooniformatter/__init__.py similarity index 100% rename from aggregatetunnelmetrics/ooniformatter/__init__.py rename to Attic/ooniformatter/__init__.py diff --git a/aggregatetunnelmetrics/ooniformatter/formatter.py b/Attic/ooniformatter/formatter.py similarity index 100% rename from aggregatetunnelmetrics/ooniformatter/formatter.py rename to Attic/ooniformatter/formatter.py diff --git a/oonisubmitter/DESIGN.md b/Attic/oonisubmitter/DESIGN.md similarity index 100% rename from oonisubmitter/DESIGN.md rename to Attic/oonisubmitter/DESIGN.md diff --git a/oonisubmitter/__init__.py b/Attic/oonisubmitter/__init__.py similarity index 100% rename from oonisubmitter/__init__.py rename to Attic/oonisubmitter/__init__.py diff --git a/oonisubmitter/aggregator.py b/Attic/oonisubmitter/aggregator.py similarity index 100% rename from oonisubmitter/aggregator.py rename to Attic/oonisubmitter/aggregator.py diff --git a/oonisubmitter/identifiers.py b/Attic/oonisubmitter/identifiers.py similarity index 100% rename from oonisubmitter/identifiers.py rename to Attic/oonisubmitter/identifiers.py diff --git a/oonisubmitter/lockedfile.py b/Attic/oonisubmitter/lockedfile.py similarity index 100% rename from oonisubmitter/lockedfile.py rename to Attic/oonisubmitter/lockedfile.py diff --git a/oonisubmitter/model.py b/Attic/oonisubmitter/model.py similarity index 100% rename from oonisubmitter/model.py rename to Attic/oonisubmitter/model.py diff --git a/oonisubmitter/ooniapi.py b/Attic/oonisubmitter/ooniapi.py similarity index 100% rename from oonisubmitter/ooniapi.py rename to Attic/oonisubmitter/ooniapi.py diff --git a/oonisubmitter/policy.py b/Attic/oonisubmitter/policy.py similarity index 100% rename from oonisubmitter/policy.py rename to Attic/oonisubmitter/policy.py diff --git a/oonisubmitter/serializer.py b/Attic/oonisubmitter/serializer.py similarity index 100% rename from oonisubmitter/serializer.py rename to Attic/oonisubmitter/serializer.py diff --git a/oonisubmitter/submitter.py b/Attic/oonisubmitter/submitter.py similarity index 100% rename from oonisubmitter/submitter.py rename to Attic/oonisubmitter/submitter.py diff --git a/tests/aggregatetunnelmetrics/globalscope/__init__.py b/Attic/t/globalscope/__init__.py similarity index 100% rename from tests/aggregatetunnelmetrics/globalscope/__init__.py rename to Attic/t/globalscope/__init__.py diff --git a/tests/aggregatetunnelmetrics/globalscope/test_aggregate.py b/Attic/t/globalscope/test_aggregate.py similarity index 100% rename from tests/aggregatetunnelmetrics/globalscope/test_aggregate.py rename to Attic/t/globalscope/test_aggregate.py diff --git a/tests/aggregatetunnelmetrics/ooniformatter/__init__.py b/Attic/t/ooniformatter/__init__.py similarity index 100% rename from tests/aggregatetunnelmetrics/ooniformatter/__init__.py rename to Attic/t/ooniformatter/__init__.py diff --git a/tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py b/Attic/t/ooniformatter/test_formatter.py similarity index 100% rename from tests/aggregatetunnelmetrics/ooniformatter/test_formatter.py rename to Attic/t/ooniformatter/test_formatter.py diff --git a/tests/oonisubmitter/__init__.py b/Attic/t/oonisubmitter/__init__.py similarity index 100% rename from tests/oonisubmitter/__init__.py rename to Attic/t/oonisubmitter/__init__.py diff --git a/tests/oonisubmitter/test_aggregator.py b/Attic/t/oonisubmitter/test_aggregator.py similarity index 100% rename from tests/oonisubmitter/test_aggregator.py rename to Attic/t/oonisubmitter/test_aggregator.py diff --git a/tests/oonisubmitter/test_identifiers.py b/Attic/t/oonisubmitter/test_identifiers.py similarity index 100% rename from tests/oonisubmitter/test_identifiers.py rename to Attic/t/oonisubmitter/test_identifiers.py diff --git a/tests/oonisubmitter/testdata/expected_state.json b/Attic/t/oonisubmitter/testdata/expected_state.json similarity index 100% rename from tests/oonisubmitter/testdata/expected_state.json rename to Attic/t/oonisubmitter/testdata/expected_state.json diff --git a/tests/oonisubmitter/testdata/sample.csv b/Attic/t/oonisubmitter/testdata/sample.csv similarity index 100% rename from tests/oonisubmitter/testdata/sample.csv rename to Attic/t/oonisubmitter/testdata/sample.csv diff --git a/tunnelmetrics/__init__.py b/Attic/tunnelmetrics/__init__.py similarity index 100% rename from tunnelmetrics/__init__.py rename to Attic/tunnelmetrics/__init__.py diff --git a/tunnelmetrics/endpoint.py b/Attic/tunnelmetrics/endpoint.py similarity index 100% rename from tunnelmetrics/endpoint.py rename to Attic/tunnelmetrics/endpoint.py diff --git a/tunnelmetrics/identifiers.py b/Attic/tunnelmetrics/identifiers.py similarity index 100% rename from tunnelmetrics/identifiers.py rename to Attic/tunnelmetrics/identifiers.py diff --git a/tunnelmetrics/model.py b/Attic/tunnelmetrics/model.py similarity index 100% rename from tunnelmetrics/model.py rename to Attic/tunnelmetrics/model.py -- GitLab From b3dd1379bf6cc8e33ed1eada8aaa1091c9f19a89 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:25:30 +0100 Subject: [PATCH 63/75] fix: correct docstring --- aggregatetunnelmetrics/aggregators/endpointpool.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py index 18964e7..3e43b4e 100644 --- a/aggregatetunnelmetrics/aggregators/endpointpool.py +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -103,7 +103,9 @@ class Aggregator: Creates a new OONI Measurement from the given metrics. Args: - aggr: The metrics to convert. + pool_name: The name of the endpoint pool. + proto_name: The name of the protocol. + bodies: The list of metrics statements. Returns: The OONI Measurement. -- GitLab From 96cd0e4380e534b3e3fea7f2e10c178c80f09505 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:28:07 +0100 Subject: [PATCH 64/75] fix: rename state to pool_aggr --- aggregatetunnelmetrics/aggregators/endpointpool.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py index 3e43b4e..953985f 100644 --- a/aggregatetunnelmetrics/aggregators/endpointpool.py +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -65,8 +65,11 @@ class Aggregator: Fields: provider: Name of the VPN provider. + pool_country: Country code of the VPN pool. + time_window: Time window for the metrics. + upstream_collector: Description of the CSV metrics collector. privacy_config: Configuration for privacy filters. - state: Mutable aggregation state. + pool_aggr: Mutable state for the endpoint pool scope. """ provider: str @@ -74,7 +77,7 @@ class Aggregator: time_window: metrics.TimeWindow upstream_collector: UpstreamCollector privacy_config: privacy.Config = privacy.Config() - state: PoolAggregator = field(default_factory=PoolAggregator) + pool_aggr: PoolAggregator = field(default_factory=PoolAggregator) def aggregate( self, @@ -82,10 +85,10 @@ class Aggregator: ) -> Generator[oonicollector.Measurement]: # Walk through entries updating the mutable state for entry in entries: - self.state.update(entry) + self.pool_aggr.update(entry) # Serialize and yield each measurement - for pool_name, pool_values in self.state.pools.items(): + for pool_name, pool_values in self.pool_aggr.pools.items(): for proto_name, proto_metrics in pool_values.protocols.items(): statements = proto_metrics.statements() -- GitLab From 7839a96af6c73e30e1a4f4b746b4a3465048e19d Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:29:57 +0100 Subject: [PATCH 65/75] fix: rename variables as requested --- aggregatetunnelmetrics/pipeline/pipeline.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/aggregatetunnelmetrics/pipeline/pipeline.py b/aggregatetunnelmetrics/pipeline/pipeline.py index 274d138..99afc6c 100644 --- a/aggregatetunnelmetrics/pipeline/pipeline.py +++ b/aggregatetunnelmetrics/pipeline/pipeline.py @@ -152,7 +152,7 @@ class Pipeline: def _transform_to_measurements( self, - entries: Generator[WindowEntries], + window_entries: Generator[WindowEntries], ) -> Generator[WindowMeasurements]: """ Stage 2: Transform WindowEntries into Measurement using aggregation. @@ -163,14 +163,14 @@ class Pipeline: Returns: Generator of WindowMeasurements. """ - for entry in entries: + for window_entry in window_entries: # Create the aggregator for this window aggr: aggregator.Logic = endpointpool.Aggregator( provider=self.provider, pool_country=self.upstream_collector.cc, time_window=metrics.TimeWindow( - start=entry.window.start, - end=entry.window.end, + start=window_entry.window.start, + end=window_entry.window.end, ), upstream_collector=self.upstream_collector, privacy_config=self.privacy_config, @@ -178,8 +178,8 @@ class Pipeline: # Stream measurements back yield WindowMeasurements( - measurements=aggr.aggregate((e for e in entry.entries)), - window=entry.window, + measurements=aggr.aggregate((e for e in window_entry.entries)), + window=window_entry.window, ) def _submit_to_collector( -- GitLab From 0f5642276e5c71623aea9c71330a26bf38620671 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:35:49 +0100 Subject: [PATCH 66/75] fix: Aggregator -> AggregatorLogic --- aggregatetunnelmetrics/aggregators/endpointpool.py | 2 +- aggregatetunnelmetrics/pipeline/pipeline.py | 2 +- tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py index 953985f..35c987f 100644 --- a/aggregatetunnelmetrics/aggregators/endpointpool.py +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -59,7 +59,7 @@ class PoolAggregator: @dataclass(frozen=True) -class Aggregator: +class AggregatorLogic: """ Implements aggregator.Logic for endpoint pool scope. diff --git a/aggregatetunnelmetrics/pipeline/pipeline.py b/aggregatetunnelmetrics/pipeline/pipeline.py index 99afc6c..cd4405c 100644 --- a/aggregatetunnelmetrics/pipeline/pipeline.py +++ b/aggregatetunnelmetrics/pipeline/pipeline.py @@ -165,7 +165,7 @@ class Pipeline: """ for window_entry in window_entries: # Create the aggregator for this window - aggr: aggregator.Logic = endpointpool.Aggregator( + aggr: aggregator.Logic = endpointpool.AggregatorLogic( provider=self.provider, pool_country=self.upstream_collector.cc, time_window=metrics.TimeWindow( diff --git a/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py b/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py index 31461f5..de98b07 100644 --- a/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py +++ b/tests/aggregatetunnelmetrics/aggregators/test_endpointpool.py @@ -42,7 +42,7 @@ def test_aggregator(sample_entry, upstream_collector): start=datetime.now(timezone.utc), end=datetime.now(timezone.utc) ) - aggregator = endpointpool.Aggregator( + aggregator = endpointpool.AggregatorLogic( provider="test_provider", pool_country="XX", time_window=time_window, -- GitLab From 84bcf0f524f624e6d6a376d56ef5dd910662fe8a Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@riseup.net> Date: Sun, 9 Mar 2025 09:37:34 +0000 Subject: [PATCH 67/75] Apply 1 suggestion(s) to 1 file(s) Co-authored-by: power puffin <powerpuff@riseup.net> --- aggregatetunnelmetrics/aggregators/endpointpool.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aggregatetunnelmetrics/aggregators/endpointpool.py b/aggregatetunnelmetrics/aggregators/endpointpool.py index 35c987f..366bb50 100644 --- a/aggregatetunnelmetrics/aggregators/endpointpool.py +++ b/aggregatetunnelmetrics/aggregators/endpointpool.py @@ -26,7 +26,7 @@ class ProtocolAggregator: State for aggregating by protocol. Fields: - protocols: Maps protocol name to ProtocolSpecificMetrics. + protocols: Maps protocol name to ProtocolMetrics. """ protocols: dict[str, ProtocolMetrics] = field(default_factory=dict) -- GitLab From 82fed592fca4e8ff6a7ffee3019f9bab6952d78a Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:40:39 +0100 Subject: [PATCH 68/75] fix: apply more suggestions by powerpuffins --- aggregatetunnelmetrics/pipeline/pipeline.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aggregatetunnelmetrics/pipeline/pipeline.py b/aggregatetunnelmetrics/pipeline/pipeline.py index cd4405c..a118d7f 100644 --- a/aggregatetunnelmetrics/pipeline/pipeline.py +++ b/aggregatetunnelmetrics/pipeline/pipeline.py @@ -123,6 +123,7 @@ class Pipeline: Args: csv_path: Input CSV file path + state: Pipeline state to use Returns: Iterator of window entries @@ -158,7 +159,7 @@ class Pipeline: Stage 2: Transform WindowEntries into Measurement using aggregation. Args: - window_files: Generator of WindowEntries to process. + window_entries: Generator of WindowEntries to process. Returns: Generator of WindowMeasurements. @@ -192,7 +193,8 @@ class Pipeline: the pipeline state to know when we stopped processing. Args: - measurements: Generator of WindowMeasurements to submit. + entries: Generator of WindowMeasurements to submit. + state: Pipeline state to update. """ report_id: oonicollector.ReportID | None = None -- GitLab From c6daa2ca79a4a8d941644912b18575fad1f3e5ec Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:43:44 +0100 Subject: [PATCH 69/75] fix: ndt_download bug spotted by powerpuffins --- aggregatetunnelmetrics/spec/metrics.py | 2 +- tests/aggregatetunnelmetrics/spec/test_metrics.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aggregatetunnelmetrics/spec/metrics.py b/aggregatetunnelmetrics/spec/metrics.py index a36344c..ec60748 100644 --- a/aggregatetunnelmetrics/spec/metrics.py +++ b/aggregatetunnelmetrics/spec/metrics.py @@ -287,7 +287,7 @@ class TunnelNDTStatement: "target_address": self.target_address, "target_port": self.target_port, "sample_size": self.sample_size, - "type": "ndt_download", + "type": f"ndt_{self.direction}", "latency_ms": self.latency.as_dict() if self.latency else None, "speed_mbits": self.speed.as_dict() if self.speed else None, "rexmit": self.rexmit.as_dict() if self.rexmit else None, diff --git a/tests/aggregatetunnelmetrics/spec/test_metrics.py b/tests/aggregatetunnelmetrics/spec/test_metrics.py index bf16b29..dc9b65e 100644 --- a/tests/aggregatetunnelmetrics/spec/test_metrics.py +++ b/tests/aggregatetunnelmetrics/spec/test_metrics.py @@ -126,7 +126,7 @@ def test_tunnel_ndt_statement_none_values(): "target_address": None, "target_port": None, "sample_size": None, - "type": "ndt_download", + "type": "ndt_upload", "latency_ms": None, "speed_mbits": None, "rexmit": None, -- GitLab From 7b83dee939734df66dd9ef189ab0605483cd7eda Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:48:04 +0100 Subject: [PATCH 70/75] fix(Attic): apply suggestion by cyberta --- Attic/globalscope/aggregate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Attic/globalscope/aggregate.py b/Attic/globalscope/aggregate.py index a8fd24d..001db00 100644 --- a/Attic/globalscope/aggregate.py +++ b/Attic/globalscope/aggregate.py @@ -30,7 +30,7 @@ def datetime_to_compact_utc(dt: datetime) -> str: @dataclass class AggregateProtocolState: - """Flat representation of the ggregated state at global scope.""" + """Flat representation of the aggregated state at global scope.""" # Core identification provider: str -- GitLab From 3f31905b09eaf7f7226e2d4588d43f4b24f52bac Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:49:21 +0100 Subject: [PATCH 71/75] fix: apply suggestion by cyberta --- Attic/tunnelmetrics/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Attic/tunnelmetrics/model.py b/Attic/tunnelmetrics/model.py index ff935f4..5c9540c 100644 --- a/Attic/tunnelmetrics/model.py +++ b/Attic/tunnelmetrics/model.py @@ -28,7 +28,7 @@ class AggregatorConfig: probe_cc: str scope: Scope = Scope.ENDPOINT # for now we only care about this - # threshold below which we emit sample_size + # threshold below which we omit sample_size min_sample_size: int = 1000 # rounding sample_size to the nearest round_to -- GitLab From 5cccd0412bc519874510a5c326b45c58d96facf9 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:57:24 +0100 Subject: [PATCH 72/75] chore: capture cyberta's comments into a TODO comment --- aggregatetunnelmetrics/aggregators/common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py index 098f06a..0542a7f 100644 --- a/aggregatetunnelmetrics/aggregators/common.py +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -52,6 +52,9 @@ class CreationMetrics: """Updates the metrics with a new entry.""" if entry.is_tunnel_measurement(): if entry.is_tunnel_error_measurement(): + # TODO(bassosimone,cyberta,powerpuffin): based on the review comments, + # we agree that here it would be nice to also evaluate the `entry.err_message` + # rather than using a generic error. error_type = "bootstrap.generic_error" self.errors[error_type] = self.errors.get(error_type, 0) + 1 self.num_samples += 1 -- GitLab From a163f38f8fd7f4c111ba4126d32a68af1028e0bd Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 10:58:56 +0100 Subject: [PATCH 73/75] fix: be more specific in TODO comment --- aggregatetunnelmetrics/aggregators/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py index 0542a7f..4261fac 100644 --- a/aggregatetunnelmetrics/aggregators/common.py +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -53,8 +53,9 @@ class CreationMetrics: if entry.is_tunnel_measurement(): if entry.is_tunnel_error_measurement(): # TODO(bassosimone,cyberta,powerpuffin): based on the review comments, - # we agree that here it would be nice to also evaluate the `entry.err_message` - # rather than using a generic error. + # we agree that here we should evaluate the `entry.err_message` rather than + # using a generic error. We should additionally consider the need of + # scrubbing the message to avoid leaking sensitive information. error_type = "bootstrap.generic_error" self.errors[error_type] = self.errors.get(error_type, 0) + 1 self.num_samples += 1 -- GitLab From c09ed10af4cd323bd98c7f0741c09a6c2a85bc7a Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 11:07:53 +0100 Subject: [PATCH 74/75] fix(NetworkErrorStatement): make the phase a parameter --- aggregatetunnelmetrics/aggregators/common.py | 1 + aggregatetunnelmetrics/spec/metrics.py | 4 +++- .../aggregatetunnelmetrics/aggregators/test_privacy.py | 10 +++++----- tests/aggregatetunnelmetrics/spec/test_metrics.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/aggregatetunnelmetrics/aggregators/common.py b/aggregatetunnelmetrics/aggregators/common.py index 4261fac..0f82737 100644 --- a/aggregatetunnelmetrics/aggregators/common.py +++ b/aggregatetunnelmetrics/aggregators/common.py @@ -66,6 +66,7 @@ class CreationMetrics: for error, count in self.errors.items(): result.append( metrics.NetworkErrorStatement( + phase="creation", sample_size=self.num_samples, failure_ratio=count / self.num_samples, error=error, diff --git a/aggregatetunnelmetrics/spec/metrics.py b/aggregatetunnelmetrics/spec/metrics.py index ec60748..b4f76c9 100644 --- a/aggregatetunnelmetrics/spec/metrics.py +++ b/aggregatetunnelmetrics/spec/metrics.py @@ -160,6 +160,7 @@ class NetworkErrorStatement: Statement about the network errors that occurred. Fields: + phase: the tunnel lifecycle phase in which the error occurred. sample_size: the number of samples. failure_ratio: the ratio of failures. error: the error that occurred. @@ -168,6 +169,7 @@ class NetworkErrorStatement: as_dict: Implements the Statement protocol. """ + phase: str sample_size: int | None failure_ratio: float error: str @@ -175,7 +177,7 @@ class NetworkErrorStatement: def as_dict(self) -> dict: """Implements the Statement protocol.""" return { - "phase": "creation", + "phase": self.phase, "sample_size": self.sample_size, "type": "network-error", "failure_ratio": self.failure_ratio, diff --git a/tests/aggregatetunnelmetrics/aggregators/test_privacy.py b/tests/aggregatetunnelmetrics/aggregators/test_privacy.py index 37a2480..8adf571 100644 --- a/tests/aggregatetunnelmetrics/aggregators/test_privacy.py +++ b/tests/aggregatetunnelmetrics/aggregators/test_privacy.py @@ -25,7 +25,7 @@ def test_filter_network_error(): config = privacy.Config(min_sample_size=1000, round_to=100) stmt = metrics.NetworkErrorStatement( - sample_size=1234, failure_ratio=0.1234, error="test_error" + phase="creation", sample_size=1234, failure_ratio=0.1234, error="test_error" ) filtered = privacy.filter_network_error(stmt, config) @@ -35,7 +35,7 @@ def test_filter_network_error(): # Test below minimum sample size stmt = metrics.NetworkErrorStatement( - sample_size=500, failure_ratio=0.1, error="test_error" + phase="creation", sample_size=500, failure_ratio=0.1, error="test_error" ) filtered = privacy.filter_network_error(stmt, config) assert filtered.sample_size is None @@ -84,7 +84,7 @@ def test_filter_tunnel_ndt(): def test_filter_network_error_none_sample(): """Test filtering network error statement with None sample size""" stmt = metrics.NetworkErrorStatement( - sample_size=None, failure_ratio=0.5, error="test_error" + phase="creation", sample_size=None, failure_ratio=0.5, error="test_error" ) config = privacy.Config() @@ -116,7 +116,7 @@ def test_filter_tunnel_ndt_none_sample(): def test_filter_statement_network_error(): """Test filtering network error through generic filter""" stmt = metrics.NetworkErrorStatement( - sample_size=1234, failure_ratio=0.5, error="test_error" + phase="creation", sample_size=1234, failure_ratio=0.5, error="test_error" ) config = privacy.Config() @@ -146,7 +146,7 @@ def test_filter_test_keys(): # Create a mix of different statement types statements = [ metrics.NetworkErrorStatement( - sample_size=1234, failure_ratio=0.5, error="test_error" + phase="creation", sample_size=1234, failure_ratio=0.5, error="test_error" ), metrics.TunnelPingStatement( target_address="1.1.1.1", diff --git a/tests/aggregatetunnelmetrics/spec/test_metrics.py b/tests/aggregatetunnelmetrics/spec/test_metrics.py index dc9b65e..629ebe0 100644 --- a/tests/aggregatetunnelmetrics/spec/test_metrics.py +++ b/tests/aggregatetunnelmetrics/spec/test_metrics.py @@ -73,7 +73,7 @@ def test_endpoint_scope(): def test_network_error_statement(): stmt = metrics.NetworkErrorStatement( - sample_size=100, failure_ratio=0.1, error="connection_failed" + phase="creation", sample_size=100, failure_ratio=0.1, error="connection_failed" ) assert stmt.as_dict() == { "phase": "creation", -- GitLab From b2f220d553ea4d414d89a9c1ed3a4f7cec81a500 Mon Sep 17 00:00:00 2001 From: Simone Basso <bassosimone@gmail.com> Date: Sun, 9 Mar 2025 11:39:20 +0100 Subject: [PATCH 75/75] doc: document fundamental issue with ERROR/tunnel --- aggregatetunnelmetrics/spec/fieldtesting.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/aggregatetunnelmetrics/spec/fieldtesting.py b/aggregatetunnelmetrics/spec/fieldtesting.py index 3a84345..2a4cac9 100644 --- a/aggregatetunnelmetrics/spec/fieldtesting.py +++ b/aggregatetunnelmetrics/spec/fieldtesting.py @@ -66,6 +66,16 @@ class Entry: def is_tunnel_error_measurement(self) -> bool: """Return whether this is a failed tunnel measurement""" + # TODO(bassosimone): this method is used wrongly and my assumption + # about the semantics of this error is also completely wrong. + # + # The `ERROR/tunnel` error means that anything went wrong during the + # lifecycle of a tunnel AND DOES NOT MEAN that we failed to start + # the tunnel itself. + # + # As a result, we need to revisit the logic. + # + # See https://0xacab.org/solitech/monitoring/-/merge_requests/6#note_1247230. return self.tunnel == "ERROR/tunnel" -- GitLab