diff --git a/Cargo.lock b/Cargo.lock index 3120e6b..38f0e19 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1167,6 +1167,7 @@ dependencies = [ "diesel", "diesel_migrations", "figment", + "futures", "hmac", "jwt", "mimalloc", @@ -1512,9 +1513,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.71" +version = "1.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "063bf466a64011ac24040a49009724ee60a57da1b437617ceb32e53ad61bfb19" +checksum = "d0ffa0837f2dfa6fb90868c2b5468cad482e175f7dad97e7421951e663f2b527" dependencies = [ "itoa", "ryu", @@ -1694,9 +1695,9 @@ checksum = "6bdef32e8150c2a081110b42772ffe7d7c9032b606bc226c8260fd97e0976601" [[package]] name = "syn" -version = "1.0.81" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2afee18b8beb5a596ecb4a2dce128c719b4ba399d34126b9e4396e3f9860966" +checksum = "8daf5dd0bb60cbd4137b1b587d2fc0ae729bc07cf01cd70b36a1ed5ade3b9d59" dependencies = [ "proc-macro2", "quote", @@ -1848,6 +1849,7 @@ checksum = "9e99e1983e5d376cd8eb4b66604d2e99e79f5bd988c3055891dcd8c9e2604cc0" dependencies = [ "bytes", "futures-core", + "futures-io", "futures-sink", "log", "pin-project-lite", diff --git a/Cargo.toml b/Cargo.toml index 1cca418..4bb2b2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -40,8 +40,9 @@ base64 = "0.13.0" figment = { version = "*", features = [ "yaml" ] } mimalloc = { version = "0.1.26", default_features = false } reqwest = { version = "0.11.6", features = [ "stream" ] } -tokio-util = {version="*", features=["io"]} +tokio-util = {version="*", features=["compat"]} bytes = "*" +futures = "0.3.18" [profile.dev] lto = "off" diff --git a/src/proxy.rs b/src/proxy.rs index fce6510..8801647 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,16 +1,21 @@ -use std::path::{Path, PathBuf}; +use std::{ + io::Cursor, + path::{Path, PathBuf}, +}; + +use futures::stream::TryStreamExt; use reqwest::{ header::{HeaderMap, HeaderName, HeaderValue}, Method as ReqMethod, }; use rocket::{ - response::stream::ByteStream, data::ToByteUnit, http::{Header, Method}, - response::{Redirect}, + response::{stream::ByteStream, Redirect}, route::{Handler, Outcome}, Data, Request, Route, }; +use tokio_util::compat::FuturesAsyncReadCompatExt; #[derive(Clone)] pub struct ProxyServer @@ -26,7 +31,11 @@ impl ProxyServer pub fn new>(root: &str, prefix: P, rank: isize) -> Self { - ProxyServer { root: String::from(root), prefix: prefix.as_ref().into(), rank } + ProxyServer { + root: String::from(root), + prefix: prefix.as_ref().into(), + rank, + } } pub fn from>(root: &str, prefix: P) -> Self @@ -51,8 +60,15 @@ impl Into> for ProxyServer for method in METHODS { let mut route = Route::ranked(self.rank, method, "/", self.clone()); - route.name = - Some(format!("ProxyServer: {} {}{}", method.as_str(), self.root.clone(), self.prefix.to_str().unwrap()).into()); + route.name = Some( + format!( + "ProxyServer: {} {}{}", + method.as_str(), + self.root.clone(), + self.prefix.to_str().unwrap() + ) + .into(), + ); routes.push(route); } @@ -94,7 +110,10 @@ impl Handler for ProxyServer // We first extract all URL segments starting from the mountpoint let segments: PathBuf = req.segments(0..).unwrap(); - let query_part = req.uri().query().map_or(String::from(""), |s| format!("?{}", s)); + let query_part = req + .uri() + .query() + .map_or(String::from(""), |s| format!("?{}", s)); let url = format!( "{}{}/{}{}", @@ -123,19 +142,25 @@ impl Handler for ProxyServer .send() .await .unwrap(); - println!("{:?}", new_res); + let mut res_headers = new_res.headers().clone(); let mut res_builder = rocket::Response::build(); - let mut res_builder = res_builder.status(rocket::http::Status::new(new_res.status().as_u16())); - // .streamed_body(ByteStream::from(new_res.bytes_stream())); - - // reqwest headers -> rocket headers - for (key, value) in new_res.headers().clone().iter_mut() { - res_builder = res_builder.raw_header( - String::from(key.clone().as_str()), - String::from(value.clone().to_str().unwrap()) + let mut res_builder = res_builder + .status(rocket::http::Status::new(new_res.status().as_u16())) + .streamed_body( + new_res + .bytes_stream() + .map_err(|e| futures::io::Error::new(futures::io::ErrorKind::Other, e)) + .into_async_read() + .compat(), ); + // reqwest headers -> rocket headers + for (key, value) in res_headers.iter_mut() { + res_builder = res_builder.raw_header( + String::from(key.clone().as_str()), + String::from(value.clone().to_str().unwrap()), + ); } Outcome::Success(res_builder.finalize())