use sessions instead of cache

This commit is contained in:
b1ek 2024-10-22 13:43:47 +10:00
parent e083420c5c
commit bab5cdaf2e
Signed by: blek
GPG Key ID: A622C22C9BC616B2
4 changed files with 93 additions and 102 deletions

View File

@ -5,7 +5,7 @@ use std::process::exit;
use reqwest::{header::{HeaderMap, HeaderValue}, Client}; use reqwest::{header::{HeaderMap, HeaderValue}, Client};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{cache::Cache, config::Config}; use crate::{config::Config, session::Session};
use crate::{WARN, RED, RESET}; use crate::{WARN, RED, RESET};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -20,6 +20,12 @@ pub struct ChatPayload {
pub messages: Vec<ChatMessagePayload> pub messages: Vec<ChatMessagePayload>
} }
impl ChatMessagePayload {
pub fn is_ai(&self) -> bool {
self.role == "assistant"
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatChunk { pub struct ChatChunk {
pub role: Option<String>, pub role: Option<String>,
@ -30,6 +36,15 @@ pub struct ChatChunk {
pub model: Option<String> pub model: Option<String>
} }
impl Into<ChatMessagePayload> for ChatChunk {
fn into(self) -> ChatMessagePayload {
ChatMessagePayload {
role: "assistant".to_string(),
content: self.message
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrChatChunk { pub struct ErrChatChunk {
pub action: String, pub action: String,
@ -87,10 +102,15 @@ pub async fn get_vqd(cli: &Client) -> Result<String, Box<dyn Error>> {
} }
} }
pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mut Cache, config: &Config) { pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, config: &Config) {
let init_msg = ChatMessagePayload { role: "user".into(), content: query };
let mut session = Session::create_or_restore(&vqd);
session.push_message(&init_msg).unwrap();
let payload = ChatPayload { let payload = ChatPayload {
model: config.model.to_string(), model: config.model.to_string(),
messages: vec![ ChatMessagePayload { role: "user".into(), content: query } ] messages: session.get_messages()
}; };
let payload = serde_json::to_string(&payload).unwrap(); let payload = serde_json::to_string(&payload).unwrap();
@ -105,15 +125,16 @@ pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mu
let new_vqd = res.headers().iter().find(|x| x.0 == "x-vqd-4"); let new_vqd = res.headers().iter().find(|x| x.0 == "x-vqd-4");
let vqd_set_res = let vqd_set_res =
if let Some(new_vqd) = new_vqd { if let Some(new_vqd) = new_vqd {
cache.set_last_vqd(new_vqd.1.as_bytes().iter().map(|x| char::from(*x)).collect::<String>()) session.set_last_vqd(new_vqd.1.as_bytes().iter().map(|x| char::from(*x)).collect::<String>())
} else { } else {
eprintln!("{WARN}Warn: DuckDuckGo did not return new VQD. Ignore this if everything else is ok.{RESET}"); eprintln!("{WARN}Warn: DuckDuckGo did not return new VQD. Ignore this if everything else is ok.{RESET}");
cache.set_last_vqd(vqd.clone()) session.set_last_vqd(vqd.clone())
}; };
if let Err(err) = vqd_set_res { if let Err(err) = vqd_set_res {
eprintln!("{WARN}Warn: Could not save VQD to cache: {err}{RESET}"); eprintln!("{WARN}Warn: Could not save VQD to cache: {err}{RESET}");
} }
let mut error = None;
while let Some(chunk) = res.chunk().await.unwrap() { while let Some(chunk) = res.chunk().await.unwrap() {
if let Ok(obj) = serde_json::from_slice::<ErrChatChunk>(&chunk) { if let Ok(obj) = serde_json::from_slice::<ErrChatChunk>(&chunk) {
@ -127,9 +148,16 @@ pub async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mu
let chunk = chunk.replace("data: ", ""); let chunk = chunk.replace("data: ", "");
for line in chunk.lines() { for line in chunk.lines() {
if let Ok(obj) = serde_json::from_str::<ChatChunk>(line) { if let Ok(obj) = serde_json::from_str::<ChatChunk>(line) {
if let Err(err) = session.push_ai_message_chunk(&obj) {
error = Some(err);
}
print!("{}", obj.message); print!("{}", obj.message);
} }
} }
} }
if let Some(err) = error {
eprintln!("Error while writing to session: {err:#?}");
eprintln!("Session may be broken.");
}
println!("\n"); println!("\n");
} }

View File

@ -1,81 +0,0 @@
use std::{env, error::Error, fs, io, path::PathBuf};
use home::home_dir;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Cache {
pub last_vqd: String,
pub last_vqd_time: u64
}
impl Default for Cache {
fn default() -> Self {
Self {
last_vqd: "".into(),
last_vqd_time: 0
}
}
}
impl Cache {
pub fn get_path<T: From<String>>() -> T {
match env::var("HEY_CACHE_PATH") {
Ok(v) => v,
Err(_) =>
match home_dir() {
Some(home) => home.join(".cache/hey").as_os_str().as_encoded_bytes().iter().map(|x| char::from(*x)).collect::<String>(),
None => panic!("Cannot detect your home directory!")
}
}.into()
}
pub fn get_file_name<T: From<String>>() -> T {
match env::var("HEY_CACHE_FILENAME") {
Ok(v) => v,
Err(_) => "cache.json".into()
}.into()
}
fn ensure_dir_exists() -> io::Result<()> {
let path: PathBuf = Self::get_path();
if ! path.is_dir() { fs::create_dir_all(path)? }
Ok(())
}
pub fn load() -> Result<Self, Box<dyn Error>> {
let path: PathBuf = Self::get_path();
Self::ensure_dir_exists()?;
let file_path = path.join(Self::get_file_name::<PathBuf>());
if ! file_path.is_file() {
let def = Self::default();
def.save()?;
Ok(def)
} else {
let file = fs::read_to_string(file_path)?;
Ok(serde_json::from_str(&file)?)
}
}
pub fn save(self: &Self) -> Result<(), Box<dyn Error>> {
let path: PathBuf = Self::get_path();
Self::ensure_dir_exists()?;
let file_path = path.join(Self::get_file_name::<PathBuf>());
fs::write(file_path, serde_json::to_string_pretty(self)?)?;
Ok(())
}
pub fn set_last_vqd<T: Into<String>>(self: &mut Self, vqd: T) -> Result<(), Box<dyn Error>> {
self.last_vqd = vqd.into();
self.last_vqd_time = chrono::Local::now().timestamp_millis() as u64;
self.save()?;
Ok(())
}
pub fn get_last_vqd<'a, T: From<&'a String>>(self: &'a Self) -> Option<T> {
None
}
}

View File

@ -3,12 +3,12 @@ use std::process::exit;
use reqwest::Client; use reqwest::Client;
use clap::Parser; use clap::Parser;
use session::Session;
use std::io::{stdout, IsTerminal}; use std::io::{stdout, IsTerminal};
use crate::{cache::Cache, config::Config}; use crate::config::Config;
use crate::api::{get_res, get_vqd, simulate_browser_reqs}; use crate::api::{get_res, get_vqd, simulate_browser_reqs};
mod cache;
mod config; mod config;
mod api; mod api;
mod session; mod session;
@ -40,9 +40,12 @@ async fn main() {
} }
let args = Args::parse(); let args = Args::parse();
let query = args.query.join(" "); let query = args.query.join(" ").trim().to_string();
if query.len() == 0 {
exit(0);
}
let mut cache = Cache::load().unwrap();
let mut config = Config::load().unwrap(); let mut config = Config::load().unwrap();
if args.agree_tos { if args.agree_tos {
@ -59,11 +62,13 @@ async fn main() {
let cli = Client::new(); let cli = Client::new();
simulate_browser_reqs(&cli).await.unwrap(); simulate_browser_reqs(&cli).await.unwrap();
let vqd = match cache.get_last_vqd() { let vqd = match Session::restore_vqd() {
Some(v) => { println!("using cached vqd"); v}, Some(v) => { v },
None => get_vqd(&cli).await.unwrap() None => get_vqd(&cli).await.unwrap()
}; };
get_res(&cli, query, vqd, &mut cache, &config).await; println!("{vqd:?}");
get_res(&cli, query, vqd, &config).await;
} }

View File

@ -6,13 +6,14 @@ use std::time::Duration;
use chrono::{DateTime, Local}; use chrono::{DateTime, Local};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::api::ChatMessagePayload; use crate::api::{ChatChunk, ChatMessagePayload};
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session { pub struct Session {
session_id: String, session_id: String,
ttl: DateTime<Local>, ttl: DateTime<Local>,
messages: Vec<ChatMessagePayload> messages: Vec<ChatMessagePayload>,
vqd: String
} }
impl Session { impl Session {
@ -66,11 +67,12 @@ impl Session {
} }
match inner(Self::path_for_id(id)) { match inner(Self::path_for_id(id)) {
Ok(session) => { Ok(mut session) => {
if session.is_expired() { if session.is_expired() {
session.destroy().expect("Couldn't destroy expired session"); session.destroy().expect("Couldn't destroy expired session");
None None
} else { } else {
session.increase_ttl().expect("Couldn't increase TTL");
Some(session) Some(session)
} }
}, },
@ -78,16 +80,25 @@ impl Session {
} }
} }
fn new<T: Into<String>>(id: T) -> Self { fn get_ttl() -> DateTime<Local> {
Local::now() + Duration::from_secs(60 * 5)
}
pub fn increase_ttl(&mut self) -> Result<(), Box<dyn std::error::Error>> {
self.ttl = Self::get_ttl();
Ok(self.save()?)
}
fn new<I: Into<String>, V: Into<String>>(id: I, vqd: V) -> Self {
Self { Self {
session_id: id.into(), session_id: id.into(),
ttl: Local::now() + Duration::from_secs(60 * 5), ttl: Self::get_ttl(),
messages: vec![] messages: vec![],
vqd: vqd.into()
} }
} }
fn save(&self) -> Result<(), Box<dyn std::error::Error>> { fn save(&self) -> Result<(), Box<dyn std::error::Error>> {
println!("{:?}", self.path());
fs::write(self.path(), serde_json::to_string_pretty(self)?)?; fs::write(self.path(), serde_json::to_string_pretty(self)?)?;
Ok(()) Ok(())
} }
@ -100,15 +111,43 @@ impl Session {
Ok(fs::remove_file(self.path())?) Ok(fs::remove_file(self.path())?)
} }
pub fn create_or_restore() -> Self { pub fn create_or_restore<T: Into<String>>(vqd: T) -> Self {
let session_id: String = Self::terminal_session_id(); let session_id: String = Self::terminal_session_id();
match Self::restore_with_id(&session_id) { match Self::restore_with_id(&session_id) {
Some(session) => session, Some(session) => session,
None => { None => {
let session = Self::new(&session_id); let session = Self::new(&session_id, vqd);
session.save().expect("Couldn't save new session"); session.save().expect("Couldn't save new session");
session session
} }
} }
} }
pub fn push_message(&mut self, msg: &ChatMessagePayload) -> Result<(), Box<dyn std::error::Error>> {
self.messages.push(msg.clone());
Ok(self.save()?)
}
pub fn push_ai_message_chunk(&mut self, chunk: &ChatChunk) -> Result<(), Box<dyn std::error::Error>> {
if self.messages.last().unwrap().is_ai() {
self.messages.last_mut().unwrap().content += chunk.message.as_str();
} else {
self.messages.push(chunk.clone().into());
}
Ok(self.save()?)
}
pub fn get_messages(&self) -> Vec<ChatMessagePayload> {
self.messages.clone()
}
pub fn restore_vqd() -> Option<String> {
Self::restore_with_id(Self::terminal_session_id::<String>()).map(|x| x.vqd)
}
pub fn set_last_vqd<T: Into<String>>(&mut self, vqd: T) -> Result<(), Box<dyn std::error::Error>> {
self.vqd = vqd.into();
Ok(self.save()?)
}
} }