From 0f7eba717e5ec478ed4e78f7ba9e638b59962d42 Mon Sep 17 00:00:00 2001 From: Matthew Esposito Date: Fri, 28 Jun 2024 22:39:42 -0400 Subject: [PATCH] fix(client): Handle invalid reddit response of base URL location --- src/client.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/client.rs b/src/client.rs index 7281df13..d6473ca9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -3,6 +3,7 @@ use cached::proc_macro::cached; use futures_lite::future::block_on; use futures_lite::{future::Boxed, FutureExt}; use hyper::client::HttpConnector; +use hyper::header::HeaderValue; use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri}; use hyper_rustls::HttpsConnector; use libflate::gzip; @@ -21,6 +22,7 @@ use crate::server::RequestExt; use crate::utils::format_url; const REDDIT_URL_BASE: &str = "https://oauth.reddit.com"; +const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com"; pub static CLIENT: Lazy>> = Lazy::new(|| { let https = hyper_rustls::HttpsConnectorBuilder::new() @@ -221,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo if !redirect { return Ok(response); }; - + let location_header = response.headers().get(header::LOCATION); + if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) { + return Err("Reddit response was invalid".to_string()); + } return request( method, - response - .headers() - .get(header::LOCATION) + location_header .map(|val| { // We need to make adjustments to the URI // we get back from Reddit. Namely, we @@ -239,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo // required. // // 2. Percent-encode the path. - let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string(); + let new_path = percent_encode(val.as_bytes(), CONTROLS) + .to_string() + .trim_start_matches(REDDIT_URL_BASE) + .trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE) + .to_string(); format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" }) }) .unwrap_or_default() @@ -298,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo } } Err(e) => { - dbg_msg!("{} {}: {}", method, path, e); + dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e); Err(e.to_string()) }