Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ edition = "2021"
axum = "0.6.20"
axum-prometheus = "0.4.0"
base64 = "0.21.4"
calendar-duration = "1.0.0"
clap = { version = "4.4.4", features = ["derive"] }
metrics-exporter-prometheus = "0.12.1"
ppoprf = "0.3.1"
Expand Down
105 changes: 86 additions & 19 deletions src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
//! STAR Randomness web service route implementation

use axum::extract::{Json, State};
use std::sync::RwLockReadGuard;

use axum::extract::{Json, Path, State};
use axum::http::StatusCode;
use base64::prelude::{Engine as _, BASE64_STANDARD as BASE64};
use serde::{Deserialize, Serialize};
use tracing::debug;
use tracing::{debug, instrument};

use crate::OPRFState;
use crate::state::{OPRFInstance, OPRFState};
use ppoprf::ppoprf;

/// Request format for the randomness endpoint
/// Request structure for the randomness endpoint
#[derive(Deserialize, Debug)]
pub struct RandomnessRequest {
/// Array of points to evaluate
Expand All @@ -19,7 +21,7 @@ pub struct RandomnessRequest {
epoch: Option<u8>,
}

/// Response format for the randomness endpoint
/// Response structure for the randomness endpoint
#[derive(Serialize, Debug)]
pub struct RandomnessResponse {
/// Resulting points from the OPRF valuation
Expand All @@ -30,26 +32,34 @@ pub struct RandomnessResponse {
epoch: u8,
}

/// Response format for the info endpoint
/// Response structure for the info endpoint
/// Rename fields to match the earlier golang implementation.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct InfoResponse {
/// ServerPublicKey used to verify zero-knowledge proof
#[serde(rename = "publicKey")]
public_key: String,
/// Currently active randomness epoch
#[serde(rename = "currentEpoch")]
current_epoch: u8,
/// Timestamp of the next epoch rotation
/// This should be a string in RFC 3339 format,
/// e.g. 2023-03-14T16:33:05Z.
#[serde(rename = "nextEpochTime")]
next_epoch_time: Option<String>,
/// Maximum number of points accepted in a single request
#[serde(rename = "maxPoints")]
max_points: usize,
}

/// Response structure for the "list instances" endpoint.
#[derive(Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ListInstancesResponse {
/// A list of available instances on the server.
instances: Vec<String>,
/// The default instance on this server.
/// A requests made to /info and /randomness will utilize this instance.
default_instance: String,
}

/// Response returned to report error conditions
#[derive(Serialize, Debug)]
struct ErrorResponse {
Expand All @@ -63,6 +73,8 @@ struct ErrorResponse {
/// handling requests.
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("instance '{0}' not found")]
InstanceNotFound(String),
#[error("Couldn't lock state: RwLock poisoned")]
LockFailure,
#[error("Invalid point")]
Expand Down Expand Up @@ -93,6 +105,7 @@ impl axum::response::IntoResponse for Error {
/// Construct an http response from our error type
fn into_response(self) -> axum::response::Response {
let code = match self {
Error::InstanceNotFound(_) => StatusCode::NOT_FOUND,
// This indicates internal failure.
Error::LockFailure => StatusCode::INTERNAL_SERVER_ERROR,
// Other cases are the client's fault.
Expand All @@ -105,13 +118,28 @@ impl axum::response::IntoResponse for Error {
}
}

type Result<T> = std::result::Result<T, Error>;

fn get_server_from_state<'a>(
state: &'a OPRFState,
instance_name: &'a str,
) -> Result<RwLockReadGuard<'a, OPRFInstance>> {
Ok(state
.instances
.get(instance_name)
.ok_or_else(|| Error::InstanceNotFound(instance_name.to_string()))?
.read()?)
}

/// Process PPOPRF evaluation requests
pub async fn randomness(
State(state): State<OPRFState>,
Json(request): Json<RandomnessRequest>,
) -> Result<Json<RandomnessResponse>, Error> {
#[instrument(skip(state, request))]
async fn randomness(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are no await calls in here, this function could be synchronous, handling the full request in a single future. Might improve throughput, but I didn't check.

state: OPRFState,
instance_name: String,
request: RandomnessRequest,
) -> Result<Json<RandomnessResponse>> {
debug!("recv: {request:?}");
let state = state.read()?;
let state = get_server_from_state(&state, &instance_name)?;
let epoch = request.epoch.unwrap_or(state.epoch);
if epoch != state.epoch {
return Err(Error::BadEpoch(epoch));
Expand All @@ -138,12 +166,29 @@ pub async fn randomness(
Ok(Json(response))
}

/// Process PPOPRF epoch and key requests
pub async fn info(
/// Process PPOPRF evaluation requests using default instance
pub async fn default_instance_randomness(
State(state): State<OPRFState>,
) -> Result<Json<InfoResponse>, Error> {
Json(request): Json<RandomnessRequest>,
) -> Result<Json<RandomnessResponse>> {
let instance_name = state.default_instance.clone();
randomness(state, instance_name, request).await
}

/// Process PPOPRF evaluation requests using specific instance
pub async fn specific_instance_randomness(
State(state): State<OPRFState>,
Path(instance_name): Path<String>,
Json(request): Json<RandomnessRequest>,
) -> Result<Json<RandomnessResponse>> {
randomness(state, instance_name, request).await
}

/// Provide PPOPRF epoch and key metadata
#[instrument(skip(state))]
async fn info(state: OPRFState, instance_name: String) -> Result<Json<InfoResponse>> {
debug!("recv: info request");
let state = state.read()?;
let state = get_server_from_state(&state, &instance_name)?;
let public_key = state.server.get_public_key().serialize_to_bincode()?;
let public_key = BASE64.encode(public_key);
let response = InfoResponse {
Expand All @@ -155,3 +200,25 @@ pub async fn info(
debug!("send: {response:?}");
Ok(Json(response))
}

/// Provide PPOPRF epoch and key metadata using default instance
pub async fn default_instance_info(State(state): State<OPRFState>) -> Result<Json<InfoResponse>> {
let instance_name = state.default_instance.clone();
info(state, instance_name).await
}

/// Provide PPOPRF epoch and key metadata using specific instance
pub async fn specific_instance_info(
State(state): State<OPRFState>,
Path(instance_name): Path<String>,
) -> Result<Json<InfoResponse>> {
info(state, instance_name).await
}

// Lists all available instances, as well as the default instance
pub async fn list_instances(State(state): State<OPRFState>) -> Result<Json<ListInstancesResponse>> {
Ok(Json(ListInstancesResponse {
instances: state.instances.keys().cloned().collect(),
default_instance: state.default_instance.clone(),
}))
}
68 changes: 43 additions & 25 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,23 @@

use axum::{routing::get, routing::post, Router};
use axum_prometheus::PrometheusMetricLayer;
use calendar_duration::CalendarDuration;
use clap::Parser;
use metrics_exporter_prometheus::PrometheusHandle;
use rlimit::Resource;
use std::sync::{Arc, RwLock};
use state::{OPRFServer, OPRFState};
use tikv_jemallocator::Jemalloc;
use time::format_description::well_known::Rfc3339;
use time::OffsetDateTime;
use tracing::{debug, info, metadata::LevelFilter};
use tracing_subscriber::EnvFilter;
use util::{assert_unique_names, parse_timestamp};

#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;

mod handler;
mod state;

pub use state::OPRFState;
mod util;

#[cfg(test)]
mod tests;
Expand All @@ -27,15 +27,22 @@ mod tests;
const MAX_POINTS: usize = 1024;

/// Command line switches
#[derive(Parser, Debug)]
#[derive(Parser, Debug, Clone)]
#[command(author, version, about, long_about = None)]
pub struct Config {
/// Host and port to listen for http connections
#[arg(long, default_value = "127.0.0.1:8080")]
listen: String,
/// Duration of each randomness epoch
#[arg(long, default_value_t = 5)]
epoch_seconds: u32,
/// Name of OPRF instance contained in server. Multiple instances may be defined
/// by defining this switch multiple times. The first defined instance will
/// become the default instance.
#[arg(long = "instance-name", default_value = "main")]
instance_names: Vec<String>,
/// Duration of each randomness epoch. This switch may be defined multiple times
/// to set the epoch length for each respective instance, if multiple instances
/// are defined.
#[arg(long = "epoch-duration", value_name = "Duration string i.e. 1mon5h2s", default_values = ["5s"])]
epoch_durations: Vec<CalendarDuration>,
/// First epoch tag to make available
#[arg(long, default_value_t = 0)]
first_epoch: u8,
Expand All @@ -56,20 +63,25 @@ pub struct Config {
prometheus_listen: Option<String>,
}

/// Parse a timestamp given as a config option
fn parse_timestamp(stamp: &str) -> Result<OffsetDateTime, &'static str> {
OffsetDateTime::parse(stamp, &Rfc3339).map_err(|_| "Try something like '2023-05-15T04:30:00Z'.")
}

/// Initialize an axum::Router for our web service
/// Having this as a separate function makes testing easier.
fn app(oprf_state: OPRFState) -> Router {
Router::new()
// Friendly default route to identify the site
.route("/", get(|| async { "STAR randomness server\n" }))
// Main endpoints
.route("/randomness", post(handler::randomness))
.route("/info", get(handler::info))
// Endpoints for all instances
.route(
"/instances/:instance/randomness",
post(handler::specific_instance_randomness),
)
.route(
"/instances/:instance/info",
get(handler::specific_instance_info),
)
.route("/instances", get(handler::list_instances))
// Endpoints for default instance
.route("/randomness", post(handler::default_instance_randomness))
.route("/info", get(handler::default_instance_info))
// Attach shared state
.with_state(oprf_state)
// Logging must come after active routes
Expand Down Expand Up @@ -126,22 +138,28 @@ async fn main() {
increase_nofile_limit();
}

// Oblivious function state
info!("initializing OPRF state...");
let server = state::OPRFServer::new(&config).expect("Could not initialize PPOPRF state");
info!("epoch now {}", server.epoch);
let oprf_state = Arc::new(RwLock::new(server));
assert_unique_names(&config.instance_names);
assert!(
!config.epoch_durations.iter().any(|d| d.is_zero()),
"all epoch lengths must be non-zero"
);
assert!(
!config.instance_names.is_empty(),
"at least one instance name must be defined"
);
assert!(
config.instance_names.len() == config.epoch_durations.len(),
"instance-name switch count must match epoch-seconds switch count"
);

let metric_layer = config.prometheus_listen.as_ref().map(|listen| {
let (layer, handle) = PrometheusMetricLayer::pair();
start_prometheus_server(handle, listen.clone());
layer
});

// Spawn a background process to advance the epoch
info!("Spawning background epoch rotation task...");
let background_state = oprf_state.clone();
tokio::spawn(async move { state::epoch_loop(background_state, &config).await });
let oprf_state = OPRFServer::new(&config);
oprf_state.start_background_tasks(&config);

// Set up routes and middleware
info!("initializing routes...");
Expand Down
Loading