diff --git a/Cargo.lock b/Cargo.lock index a854ab8..3bc25ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -401,6 +401,7 @@ dependencies = [ "serde", "serde_json", "tokio", + "toml", ] [[package]] @@ -993,6 +994,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1242,6 +1252,40 @@ dependencies = [ "tracing", ] +[[package]] +name = "toml" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e40bb779c5187258fd7aad0eb68cb8706a0a81fa712fbea808ab43c4b8374c4" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + [[package]] name = "tower" version = "0.4.13" @@ -1610,6 +1654,15 @@ version = "0.52.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32b752e52a2da0ddfbdbcc6fceadfeede4c939ed16d13e648833a61dfb611ed8" +[[package]] +name = "winnow" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c976aaaa0e1f90dbb21e9587cdaf1d9679a1cde8875c0d6bd83ab96a208352" +dependencies = [ + "memchr", +] + [[package]] name = "winreg" version = "0.52.0" diff --git a/Cargo.toml b/Cargo.toml index b06b362..fce5a25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,4 @@ reqwest = "0.12.3" serde = { version = "1.0.197", features = ["derive"] } serde_json = "1.0.115" tokio = { version = "1.37.0", features = ["full"] } +toml = "0.8.12" diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..2c301d6 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,73 @@ +use std::{env, error::Error, fs, io, path::PathBuf}; + +use home::home_dir; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum Model { + #[serde(rename = "claude-instant-1.2")] + Claude12, + #[serde(rename = "gpt-3.5-turbo-0125")] + GPT35, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + pub model: Model, + pub tos: bool +} + +impl Default for Config { + fn default() -> Self { + Self { + model: Model::Claude12, + tos: false + } + } +} + +impl Config { + pub fn get_path>() -> T { + match env::var("HEY_CONFIG_PATH") { + Ok(v) => v, + Err(_) => + match home_dir() { + Some(home) => home.join(".config/hey").as_os_str().as_encoded_bytes().iter().map(|x| char::from(*x)).collect::(), + None => panic!("Cannot detect your home directory!") + } + }.into() + } + + pub fn get_file_name>() -> T { + match env::var("HEY_CONFIG_FILENAME") { + Ok(v) => v, + Err(_) => "conf.toml".into() + }.into() + } + + fn ensure_dir_exists() -> io::Result<()> { + let path: PathBuf = Self::get_path(); + if ! path.is_dir() { fs::create_dir_all(path)? } + Ok(()) + } + + pub fn save(self: &Self) -> Result<(), Box> { + let path = Self::get_path::(); + Self::ensure_dir_exists()?; + + let file_path = path.join(Self::get_file_name::()); + fs::write(file_path, toml::to_string_pretty(self)?)?; + Ok(()) + } + + pub fn load() -> Result> { + let path = Self::get_path::(); + + let file_path = path.join(Self::get_file_name::()); + if ! file_path.is_file() { + Ok(Self::default()) + } else { + Ok(toml::from_str(&fs::read_to_string(file_path)?)?) + } + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 868d569..5335c4c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,17 +1,19 @@ -use std::{error::Error, process::exit}; +use std::{error::Error, path::PathBuf, process::exit}; use reqwest::{header::{HeaderMap, HeaderValue}, Client}; use serde::{Deserialize, Serialize}; use clap::Parser; +use std::io::{stdout, IsTerminal}; -use crate::cache::Cache; +use crate::{cache::Cache, config::Config}; mod cache; +mod config; const GREEN: &str = "\x1b[1;32m"; const RED: &str = "\x1b[1;31m"; -const WARN: &str = "\x1b[1;34m"; +const WARN: &str = "\x1b[33m"; const RESET: &str = "\x1b[0m"; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -145,22 +147,46 @@ async fn get_res<'a>(cli: &Client, query: String, vqd: String, cache: &'a mut Ca #[derive(Debug, Parser)] #[command(version, about, long_about = None)] -#[clap(trailing_var_arg=true)] struct Args { + #[arg(long, default_value = "false", required = false, help = "If you want to agree to the DuckDuckGo TOS")] + pub agree_tos: bool, #[arg()] - pub query: Vec + pub query: Vec, } #[tokio::main] async fn main() { femme::start(); + + if ! stdout().is_terminal() { + eprintln!("{RED}Refusing to run in a non-terminal environment{RESET}"); + eprintln!("This is done to prevent API scraping."); + exit(2) + } let args = Args::parse(); let query = args.query.join(" "); - println!("{GREEN}Contacting DuckDuckGo chat AI...{RESET}"); - let mut cache = Cache::load().unwrap(); + let mut config = Config::load().unwrap(); + + if args.agree_tos { + if ! config.tos { + println!("{GREEN}TOS accepted{RESET}"); + } + config.tos = true; + config.save().expect("Error saving config"); + } + + if ! config.tos { + eprintln!("{RED}You need to agree to duckduckgo AI chat TOS to continue.{RESET}"); + eprintln!("{RED}Visit it on this URL: https://duckduckgo.com/?q=duckduckgo+ai+chat&ia=chat{RESET}"); + eprintln!("Once you read it, pass --agree-tos parameter to agree."); + eprintln!("{WARN}Note: if you want to, modify `tos` parameter in {}{RESET}", Config::get_path::().join(Config::get_file_name::()).display()); + exit(3); + } + + println!("{GREEN}Contacting DuckDuckGo chat AI...{RESET}"); let cli = Client::new(); simulate_browser_reqs(&cli).await.unwrap();