From bab5cdaf2e03d155105d79abb5b54998bd29187c Mon Sep 17 00:00:00 2001 From: b1ek Date: Tue, 22 Oct 2024 13:43:47 +1000 Subject: [PATCH] use sessions instead of cache --- src/api.rs | 38 +++++++++++++++++++---- src/cache.rs | 81 -------------------------------------------------- src/main.rs | 19 +++++++----- src/session.rs | 57 +++++++++++++++++++++++++++++------ 4 files changed, 93 insertions(+), 102 deletions(-) delete mode 100644 src/cache.rs diff --git a/src/api.rs b/src/api.rs index b3c61fa..eea9112 100644 --- a/src/api.rs +++ b/src/api.rs @@ -5,7 +5,7 @@ use std::process::exit; use reqwest::{header::{HeaderMap, HeaderValue}, Client}; use serde::{Deserialize, Serialize}; -use crate::{cache::Cache, config::Config}; +use crate::{config::Config, session::Session}; use crate::{WARN, RED, RESET}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -20,6 +20,12 @@ pub struct ChatPayload { pub messages: Vec } +impl ChatMessagePayload { + pub fn is_ai(&self) -> bool { + self.role == "assistant" + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatChunk { pub role: Option, @@ -30,6 +36,15 @@ pub struct ChatChunk { pub model: Option } +impl Into for ChatChunk { + fn into(self) -> ChatMessagePayload { + ChatMessagePayload { + role: "assistant".to_string(), + content: self.message + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ErrChatChunk { pub action: String, @@ -87,10 +102,15 @@ pub async fn get_vqd(cli: &Client) -> Result> { } } -pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mut Cache, config: &Config) { +pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, config: &Config) { + let init_msg = ChatMessagePayload { role: "user".into(), content: query }; + + let mut session = Session::create_or_restore(&vqd); + session.push_message(&init_msg).unwrap(); + let payload = ChatPayload { model: config.model.to_string(), - messages: vec![ ChatMessagePayload { role: "user".into(), content: query } ] + messages: session.get_messages() }; let payload = serde_json::to_string(&payload).unwrap(); @@ -105,15 +125,16 @@ pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mu let new_vqd = res.headers().iter().find(|x| x.0 == "x-vqd-4"); let vqd_set_res = if let Some(new_vqd) = new_vqd { - cache.set_last_vqd(new_vqd.1.as_bytes().iter().map(|x| char::from(*x)).collect::()) + session.set_last_vqd(new_vqd.1.as_bytes().iter().map(|x| char::from(*x)).collect::()) } else { eprintln!("{WARN}Warn: DuckDuckGo did not return new VQD. Ignore this if everything else is ok.{RESET}"); - cache.set_last_vqd(vqd.clone()) + session.set_last_vqd(vqd.clone()) }; if let Err(err) = vqd_set_res { eprintln!("{WARN}Warn: Could not save VQD to cache: {err}{RESET}"); } + let mut error = None; while let Some(chunk) = res.chunk().await.unwrap() { if let Ok(obj) = serde_json::from_slice::(&chunk) { @@ -127,9 +148,16 @@ pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mu let chunk = chunk.replace("data: ", ""); for line in chunk.lines() { if let Ok(obj) = serde_json::from_str::(line) { + if let Err(err) = session.push_ai_message_chunk(&obj) { + error = Some(err); + } print!("{}", obj.message); } } } + if let Some(err) = error { + eprintln!("Error while writing to session: {err:#?}"); + eprintln!("Session may be broken."); + } println!("\n"); } \ No newline at end of file diff --git a/src/cache.rs b/src/cache.rs deleted file mode 100644 index 969b9b0..0000000 --- a/src/cache.rs +++ /dev/null @@ -1,81 +0,0 @@ - -use std::{env, error::Error, fs, io, path::PathBuf}; -use home::home_dir; -use serde::{Deserialize, Serialize}; - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Cache { - pub last_vqd: String, - pub last_vqd_time: u64 -} - -impl Default for Cache { - fn default() -> Self { - Self { - last_vqd: "".into(), - last_vqd_time: 0 - } - } -} - -impl Cache { - - pub fn get_path>() -> T { - match env::var("HEY_CACHE_PATH") { - Ok(v) => v, - Err(_) => - match home_dir() { - Some(home) => home.join(".cache/hey").as_os_str().as_encoded_bytes().iter().map(|x| char::from(*x)).collect::(), - None => panic!("Cannot detect your home directory!") - } - }.into() - } - - pub fn get_file_name>() -> T { - match env::var("HEY_CACHE_FILENAME") { - Ok(v) => v, - Err(_) => "cache.json".into() - }.into() - } - - fn ensure_dir_exists() -> io::Result<()> { - let path: PathBuf = Self::get_path(); - if ! path.is_dir() { fs::create_dir_all(path)? } - Ok(()) - } - - pub fn load() -> Result> { - let path: PathBuf = Self::get_path(); - Self::ensure_dir_exists()?; - - let file_path = path.join(Self::get_file_name::()); - if ! file_path.is_file() { - let def = Self::default(); - def.save()?; - Ok(def) - } else { - let file = fs::read_to_string(file_path)?; - Ok(serde_json::from_str(&file)?) - } - } - - pub fn save(self: &Self) -> Result<(), Box> { - let path: PathBuf = Self::get_path(); - Self::ensure_dir_exists()?; - - let file_path = path.join(Self::get_file_name::()); - fs::write(file_path, serde_json::to_string_pretty(self)?)?; - Ok(()) - } - - pub fn set_last_vqd>(self: &mut Self, vqd: T) -> Result<(), Box> { - self.last_vqd = vqd.into(); - self.last_vqd_time = chrono::Local::now().timestamp_millis() as u64; - self.save()?; - Ok(()) - } - - pub fn get_last_vqd<'a, T: From<&'a String>>(self: &'a Self) -> Option { - None - } -} diff --git a/src/main.rs b/src/main.rs index 78a265a..a5b8a3c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,12 +3,12 @@ use std::process::exit; use reqwest::Client; use clap::Parser; +use session::Session; use std::io::{stdout, IsTerminal}; -use crate::{cache::Cache, config::Config}; +use crate::config::Config; use crate::api::{get_res, get_vqd, simulate_browser_reqs}; -mod cache; mod config; mod api; mod session; @@ -40,9 +40,12 @@ async fn main() { } let args = Args::parse(); - let query = args.query.join(" "); + let query = args.query.join(" ").trim().to_string(); + + if query.len() == 0 { + exit(0); + } - let mut cache = Cache::load().unwrap(); let mut config = Config::load().unwrap(); if args.agree_tos { @@ -59,11 +62,13 @@ async fn main() { let cli = Client::new(); simulate_browser_reqs(&cli).await.unwrap(); - let vqd = match cache.get_last_vqd() { - Some(v) => { println!("using cached vqd"); v}, + let vqd = match Session::restore_vqd() { + Some(v) => { v }, None => get_vqd(&cli).await.unwrap() }; - get_res(&cli, query, vqd, &mut cache, &config).await; + println!("{vqd:?}"); + + get_res(&cli, query, vqd, &config).await; } diff --git a/src/session.rs b/src/session.rs index b4a25cc..3bfb95d 100644 --- a/src/session.rs +++ b/src/session.rs @@ -6,13 +6,14 @@ use std::time::Duration; use chrono::{DateTime, Local}; use serde::{Deserialize, Serialize}; -use crate::api::ChatMessagePayload; +use crate::api::{ChatChunk, ChatMessagePayload}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Session { session_id: String, ttl: DateTime, - messages: Vec + messages: Vec, + vqd: String } impl Session { @@ -66,11 +67,12 @@ impl Session { } match inner(Self::path_for_id(id)) { - Ok(session) => { + Ok(mut session) => { if session.is_expired() { session.destroy().expect("Couldn't destroy expired session"); None } else { + session.increase_ttl().expect("Couldn't increase TTL"); Some(session) } }, @@ -78,16 +80,25 @@ impl Session { } } - fn new>(id: T) -> Self { + fn get_ttl() -> DateTime { + Local::now() + Duration::from_secs(60 * 5) + } + + pub fn increase_ttl(&mut self) -> Result<(), Box> { + self.ttl = Self::get_ttl(); + Ok(self.save()?) + } + + fn new, V: Into>(id: I, vqd: V) -> Self { Self { session_id: id.into(), - ttl: Local::now() + Duration::from_secs(60 * 5), - messages: vec![] + ttl: Self::get_ttl(), + messages: vec![], + vqd: vqd.into() } } fn save(&self) -> Result<(), Box> { - println!("{:?}", self.path()); fs::write(self.path(), serde_json::to_string_pretty(self)?)?; Ok(()) } @@ -100,15 +111,43 @@ impl Session { Ok(fs::remove_file(self.path())?) } - pub fn create_or_restore() -> Self { + pub fn create_or_restore>(vqd: T) -> Self { let session_id: String = Self::terminal_session_id(); match Self::restore_with_id(&session_id) { Some(session) => session, None => { - let session = Self::new(&session_id); + let session = Self::new(&session_id, vqd); session.save().expect("Couldn't save new session"); session } } } + + pub fn push_message(&mut self, msg: &ChatMessagePayload) -> Result<(), Box> { + self.messages.push(msg.clone()); + Ok(self.save()?) + } + + pub fn push_ai_message_chunk(&mut self, chunk: &ChatChunk) -> Result<(), Box> { + if self.messages.last().unwrap().is_ai() { + self.messages.last_mut().unwrap().content += chunk.message.as_str(); + } else { + self.messages.push(chunk.clone().into()); + } + + Ok(self.save()?) + } + + pub fn get_messages(&self) -> Vec { + self.messages.clone() + } + + pub fn restore_vqd() -> Option { + Self::restore_with_id(Self::terminal_session_id::()).map(|x| x.vqd) + } + + pub fn set_last_vqd>(&mut self, vqd: T) -> Result<(), Box> { + self.vqd = vqd.into(); + Ok(self.save()?) + } } \ No newline at end of file