diff --git a/Cargo.lock b/Cargo.lock index a28b9be..3120e6b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1162,6 +1162,7 @@ name = "rb-gw" version = "0.1.0" dependencies = [ "base64", + "bytes", "chrono", "diesel", "diesel_migrations", @@ -1177,6 +1178,7 @@ dependencies = [ "rust-argon2", "serde", "sha2", + "tokio-util", "uuid", ] diff --git a/Cargo.toml b/Cargo.toml index 9cb2f47..1cca418 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,6 +40,8 @@ base64 = "0.13.0" figment = { version = "*", features = [ "yaml" ] } mimalloc = { version = "0.1.26", default_features = false } reqwest = { version = "0.11.6", features = [ "stream" ] } +tokio-util = {version="*", features=["io"]} +bytes = "*" [profile.dev] lto = "off" diff --git a/src/main.rs b/src/main.rs index e8462db..e17d557 100644 --- a/src/main.rs +++ b/src/main.rs @@ -129,10 +129,10 @@ fn rocket() -> _ .expect("services config"); rocket - .mount("/v1/posts", ProxyServer::from(services_conf.blog.clone())) + .mount("/api/v1/posts", ProxyServer::from(&services_conf.blog, "/v1/posts")) .mount( - "/v1/sections", - ProxyServer::from(services_conf.blog.clone()), + "/api/v1/sections", + ProxyServer::from(&services_conf.blog, "/v1/sections"), ) .manage(jwt_conf) .manage(admin_conf) diff --git a/src/proxy.rs b/src/proxy.rs index 861ae62..53a7605 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,11 +1,13 @@ +use std::path::{Path, PathBuf}; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Method as ReqMethod, }; use rocket::{ + response::stream::ByteStream, data::ToByteUnit, - http::{ Method, Header }, - response::Redirect, + http::{Header, Method}, + response::{Redirect}, route::{Handler, Outcome}, Data, Request, Route, }; @@ -14,6 +16,7 @@ use rocket::{ pub struct ProxyServer { root: String, + prefix: PathBuf, rank: isize, } @@ -21,14 +24,14 @@ impl ProxyServer { const DEFAULT_RANK: isize = 0; - pub fn new(root: String, rank: isize) -> Self + pub fn new>(root: &str, prefix: P, rank: isize) -> Self { - ProxyServer { root, rank } + ProxyServer { root: String::from(root), prefix: prefix.as_ref().into(), rank } } - pub fn from(root: String) -> Self + pub fn from>(root: &str, prefix: P) -> Self { - Self::new(root, Self::DEFAULT_RANK) + Self::new(root, prefix, Self::DEFAULT_RANK) } } @@ -49,7 +52,7 @@ impl Into> for ProxyServer for method in METHODS { let mut route = Route::ranked(self.rank, method, "/", self.clone()); route.name = - Some(format!("ProxyServer: {} {}", method.as_str(), self.root.clone()).into()); + Some(format!("ProxyServer: {} {}{}", method.as_str(), self.root.clone(), self.prefix.to_str().unwrap()).into()); routes.push(route); } @@ -88,7 +91,16 @@ impl Handler for ProxyServer }; // Then the URL - let url = format!("{}{}?{}", self.root, req.uri().path().as_str(), req.uri().query().unwrap()); + // We first extract all URL segments starting from the mountpoint + let segments: PathBuf = req.segments(0..).unwrap(); + + let url = format!( + "{}{}/{}?{}", + self.root, + self.prefix.to_str().unwrap(), + segments.to_str().unwrap(), + req.uri().query().unwrap() + ); // And finally, the data // TODO don't hard-code max request size here @@ -111,17 +123,19 @@ impl Handler for ProxyServer .unwrap(); println!("{:?}", new_res); - let mut rocket_res = rocket::Response::new(); - rocket_res.set_status(rocket::http::Status::new(new_res.status().as_u16())); - // rocket_res.set_streamed_body(new_res.bytes_stream()); + let mut res_builder = rocket::Response::build(); + let mut res_builder = res_builder.status(rocket::http::Status::new(new_res.status().as_u16())); + // .streamed_body(ByteStream::from(new_res.bytes_stream())); // reqwest headers -> rocket headers for (key, value) in new_res.headers().clone().iter_mut() { - println!("{}", rocket_res.set_raw_header(String::from(key.clone().as_str()), String::from(value.clone().to_str().unwrap()))); + res_builder = res_builder.raw_header( + String::from(key.clone().as_str()), + String::from(value.clone().to_str().unwrap()) + ); + } - println!("{:?}", rocket_res); - - Outcome::Success(rocket_res) + Outcome::Success(res_builder.finalize()) } }