Unverified Commit 3801eb3e authored by Niklas Adolfsson's avatar Niklas Adolfsson
Browse files

fix more grumbles: no more Infallible

parent e2f1ed5f
Pipeline #207741 canceled with stages
in 3 minutes and 37 seconds
......@@ -24,7 +24,6 @@
// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
// DEALINGS IN THE SOFTWARE.
use std::convert::Infallible;
use std::future::Future;
use std::net::{SocketAddr, TcpListener as StdTcpListener};
use std::pin::Pin;
......@@ -35,7 +34,6 @@ use crate::response::{internal_error, malformed};
use futures_channel::mpsc;
use futures_util::future::FutureExt;
use futures_util::stream::{StreamExt, TryStreamExt};
use futures_util::TryFutureExt;
use hyper::body::HttpBody;
use hyper::header::{HeaderMap, HeaderValue};
use hyper::server::conn::AddrStream;
......@@ -468,10 +466,7 @@ struct ServiceData<L> {
impl<L: Logger> ServiceData<L> {
/// Default behavior for handling the RPC requests.
async fn handle_request(
self,
request: hyper::Request<hyper::Body>,
) -> Result<hyper::Response<hyper::Body>, Infallible> {
async fn handle_request(self, request: hyper::Request<hyper::Body>) -> hyper::Response<hyper::Body> {
let ServiceData {
remote_addr,
methods,
......@@ -492,23 +487,23 @@ impl<L: Logger> ServiceData<L> {
let host = match http_helpers::read_header_value(request.headers(), "host") {
Some(origin) => origin,
None => return Ok(malformed()),
None => return malformed(),
};
let maybe_origin = http_helpers::read_header_value(request.headers(), "origin");
if let Err(e) = acl.verify_host(host) {
tracing::warn!("Denied request: {:?}", e);
return Ok(response::host_not_allowed());
return response::host_not_allowed();
}
if let Err(e) = acl.verify_origin(maybe_origin, host) {
tracing::warn!("Denied request: {:?}", e);
return Ok(response::invalid_allow_origin());
return response::invalid_allow_origin();
}
if let Err(e) = acl.verify_headers(keys, cors_request_headers) {
tracing::warn!("Denied request: {:?}", e);
return Ok(response::invalid_allow_headers());
return response::invalid_allow_headers();
}
// Only `POST` and `OPTIONS` methods are allowed.
......@@ -518,13 +513,13 @@ impl<L: Logger> ServiceData<L> {
Method::OPTIONS => {
let origin = match maybe_origin {
Some(origin) => origin,
None => return Ok(malformed()),
None => return malformed(),
};
let allowed_headers = acl.allowed_headers().to_cors_header_value();
let allowed_header_bytes = allowed_headers.as_bytes();
let res = hyper::Response::builder()
hyper::Response::builder()
.header("access-control-allow-origin", origin)
.header("access-control-allow-methods", "POST")
.header("access-control-allow-headers", allowed_header_bytes)
......@@ -532,9 +527,7 @@ impl<L: Logger> ServiceData<L> {
.unwrap_or_else(|e| {
tracing::error!("Error forming preflight response: {}", e);
internal_error()
});
Ok(res)
})
}
// The actual request. If it's a CORS request we need to remember to add
// the access-control-allow-origin header (despite preflight) to allow it
......@@ -552,12 +545,12 @@ impl<L: Logger> ServiceData<L> {
batch_requests_supported,
request_start,
})
.await?;
.await;
if let Some(origin) = origin {
res.headers_mut().insert("access-control-allow-origin", origin);
}
Ok(res)
res
}
Method::GET => match health_api.as_ref() {
Some(health) if health.path.as_str() == request.uri().path() => {
......@@ -571,11 +564,11 @@ impl<L: Logger> ServiceData<L> {
)
.await
}
_ => Ok(response::method_not_allowed()),
_ => response::method_not_allowed(),
},
// Error scenarios:
Method::POST => Ok(response::unsupported_content_type()),
_ => Ok(response::method_not_allowed()),
Method::POST => response::unsupported_content_type(),
_ => response::method_not_allowed(),
}
}
}
......@@ -611,7 +604,7 @@ impl<L: Logger> hyper::service::Service<hyper::Request<hyper::Body>> for TowerSe
// Note that `handle_request` will never return error.
// The dummy error is set in place to satisfy the server's trait bounds regarding the
// `tower::ServiceBuilder` and the error will never be mapped.
Box::pin(data.handle_request(request).map_err(|_| "".into()))
Box::pin(data.handle_request(request).map(|res| Ok(res)))
}
}
......@@ -764,9 +757,7 @@ struct ProcessValidatedRequest<L: Logger> {
}
/// Process a verified request, it implies a POST request with content type JSON.
async fn process_validated_request<L: Logger>(
input: ProcessValidatedRequest<L>,
) -> Result<hyper::Response<hyper::Body>, Infallible> {
async fn process_validated_request<L: Logger>(input: ProcessValidatedRequest<L>) -> hyper::Response<hyper::Body> {
let ProcessValidatedRequest {
request,
logger,
......@@ -783,11 +774,11 @@ async fn process_validated_request<L: Logger>(
let (body, is_single) = match read_body(&parts.headers, body, max_request_body_size).await {
Ok(r) => r,
Err(GenericTransportError::TooLarge) => return Ok(response::too_large(max_request_body_size)),
Err(GenericTransportError::Malformed) => return Ok(response::malformed()),
Err(GenericTransportError::TooLarge) => return response::too_large(max_request_body_size),
Err(GenericTransportError::Malformed) => return response::malformed(),
Err(GenericTransportError::Inner(e)) => {
tracing::error!("Internal error reading request body: {}", e);
return Ok(response::internal_error());
return response::internal_error();
}
};
......@@ -804,7 +795,7 @@ async fn process_validated_request<L: Logger>(
};
let response = process_single_request(body, call).await;
logger.on_response(&response.result, request_start);
Ok(response::ok_response(response.result))
response::ok_response(response.result)
}
// Batch of requests or notifications
else if !batch_requests_supported {
......@@ -813,7 +804,7 @@ async fn process_validated_request<L: Logger>(
ErrorObject::borrowed(BATCHES_NOT_SUPPORTED_CODE, &BATCHES_NOT_SUPPORTED_MSG, None),
);
logger.on_response(&err.result, request_start);
Ok(response::ok_response(err.result))
response::ok_response(err.result)
}
// Batch of requests or notifications
else {
......@@ -831,7 +822,7 @@ async fn process_validated_request<L: Logger>(
})
.await;
logger.on_response(&response.result, request_start);
Ok(response::ok_response(response.result))
response::ok_response(response.result)
}
}
......@@ -842,7 +833,7 @@ async fn process_health_request<L: Logger>(
max_response_body_size: u32,
request_start: L::Instant,
max_log_length: u32,
) -> Result<hyper::Response<hyper::Body>, Infallible> {
) -> hyper::Response<hyper::Body> {
let trace = RpcTracing::method_call(&health_api.method);
async {
tx_log_from_str("HTTP health API", max_log_length);
......@@ -874,9 +865,9 @@ async fn process_health_request<L: Logger>(
let payload: RpcPayload = serde_json::from_str(&response.result)
.expect("valid JSON-RPC response must have a result field and be valid JSON; qed");
Ok(response::ok_response(payload.result.to_string()))
response::ok_response(payload.result.to_string())
} else {
Ok(response::internal_error())
response::internal_error()
}
}
.instrument(trace.into_span())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment