diff --git a/Cargo.lock b/Cargo.lock index 480d762..82e30a4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -335,6 +335,15 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "envy" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f47e0157f2cb54f5ae1bd371b30a2ae4311e1c028f575cd4e81de7353215965" +dependencies = [ + "serde", +] + [[package]] name = "eventsource-stream" version = "0.2.3" @@ -609,11 +618,14 @@ dependencies = [ "anyhow", "async-openai", "dotenvy", + "envy", "once_cell", + "serde", "serenity", "tokio", "tracing", "tracing-subscriber", + "typed-builder", ] [[package]] @@ -1118,9 +1130,9 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" +checksum = "8e422a44e74ad4001bdc8eede9a4570ab52f71190e9c076d14369f38b9200537" dependencies = [ "serde_derive", ] @@ -1137,9 +1149,9 @@ dependencies = [ [[package]] name = "serde_derive" -version = "1.0.188" +version = "1.0.189" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" +checksum = "1e48d1f918009ce3145511378cf68d613e3b3d9137d67272562080d68a2b32d5" dependencies = [ "proc-macro2", "quote", @@ -1527,6 +1539,26 @@ dependencies = [ "webpki", ] +[[package]] +name = "typed-builder" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e47c0496149861b7c95198088cbf36645016b1a0734cf350c50e2a38e070f38a" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "982ee4197351b5c9782847ef5ec1fdcaf50503fb19d68f9771adae314e72b492" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.29", +] + [[package]] name = "typemap_rev" version = 
"0.1.5" diff --git a/Cargo.toml b/Cargo.toml index 86ed4f0..184a011 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,12 +12,18 @@ dotenvy = { version = "0.15" } once_cell = { version = "1.18" } tracing = { version = "0.1" } tracing-subscriber = { version = "0.3" } +envy = { version = "0.4.2" } +typed-builder = { version = "0.18.0" } [dependencies.serenity] version = "0.11" features = ["client", "gateway", "model", "cache", "rustls_backend"] default-features = false +[dependencies.serde] +version = "1.0.189" +features = ["derive"] + [dependencies.tokio] version = "1.33" features = ["macros", "rt-multi-thread", "time"] diff --git a/README.md b/README.md index aa67b46..7eb141e 100644 --- a/README.md +++ b/README.md @@ -24,9 +24,9 @@ docker pull ghcr.io/approvers/ichiyo_ai:vX.Y.Z 設定の例は [.env.example](./.env.example) で確認できます。 -| Key | Description | Default | -|---------------------|-------------------|---------| -| `DISCORD_API_TOKEN` | Discord API のトークン | - | -| `OPENAI_API_KEY` | OpenAI API のトークン | - | -| `GUILD_ID` | 限界開発鯖の ID | - | -| `SUBSCRIBER_ROLE_ID` | 購読者ロールの ID | - | +| Key | Description | required | +|---------------------|-------------------|----------| +| `DISCORD_API_TOKEN` | Discord API のトークン | `Yes` | +| `OPENAI_API_KEY` | OpenAI API のトークン | `Yes` | +| `GUILD_ID` | 限界開発鯖の ID | `Yes` | +| `TAXPAYER_ROLE_ID` | 購読者ロールの ID | `No` | diff --git a/src/adapters/chatgpt.rs b/src/adapters/chatgpt.rs new file mode 100644 index 0000000..f07d3b1 --- /dev/null +++ b/src/adapters/chatgpt.rs @@ -0,0 +1,52 @@ +use crate::model::chatgpt::{RequestMessageModel, ResponseCompletionResultModel}; +use anyhow::Context; +use async_openai::config::OpenAIConfig; +use async_openai::types::CreateChatCompletionRequestArgs; +use async_openai::Client; +use std::time::Duration; +use tokio::time::timeout; + +pub static SYSTEM_CONTEXT: &str = "回答時は以下のルールに従うこと.\n- 1900文字以内に収めること。"; +static TIMEOUT_DURATION: Duration = Duration::from_secs(180); + +async fn create_chatgpt_client() -> 
anyhow::Result<Client<OpenAIConfig>> { + Ok(Client::new()) } + +pub async fn request_chatgpt_message( + request: RequestMessageModel, +) -> anyhow::Result<ResponseCompletionResultModel> { + let client = create_chatgpt_client().await?; + + let client_request = CreateChatCompletionRequestArgs::default() + .max_tokens(512u16) + .model(request.model) + .messages(request.replies) + .build()?; + + let response = timeout(TIMEOUT_DURATION, client.chat().create(client_request)) + .await + .context("Timeout. Please try again.")??; + + let choice = response + .choices + .get(0) + .context("No response message found.")?; + let (input_token, output_token, total_token) = response + .usage + .map(|usage| { + ( + usage.prompt_tokens, + usage.completion_tokens, + usage.total_tokens, + ) + }) + .unwrap_or_default(); + + Ok(ResponseCompletionResultModel::builder() + .response_message(choice.message.content.clone().unwrap_or_default()) + .input_token(input_token) + .output_token(output_token) + .total_token(total_token) + .build()) +} diff --git a/src/adapters/discord.rs b/src/adapters/discord.rs new file mode 100644 index 0000000..3009a79 --- /dev/null +++ b/src/adapters/discord.rs @@ -0,0 +1,23 @@ +use crate::model::chatgpt::{usage_pricing, ResponseCompletionResultModel}; +use crate::model::discord::DiscordReplyMessageModel; +use anyhow::Context; + +pub async fn reply_completion_result( + reply_message: DiscordReplyMessageModel, +) -> anyhow::Result<()> { + reply_message + .target_message + .reply_ping(reply_message.http, reply_message.formatted_result) + .await + .context("Failed to reply.")?; + + Ok(()) +} + +pub fn format_result(result: ResponseCompletionResultModel, model: &str) -> String { + let pricing = usage_pricing(result.input_token, result.output_token, model); + format!( + "{}\n\n`利用料金: ¥{:.2}` - `合計トークン: {}` - `使用モデル: {}`", + result.response_message, pricing, result.total_token, model + ) +} diff --git a/src/adapters/mod.rs b/src/adapters/mod.rs new file mode 100644 index 0000000..ae3cbef --- /dev/null +++ 
b/src/adapters/mod.rs @@ -0,0 +1,3 @@ +pub mod chatgpt; + +pub mod discord; diff --git a/src/client/discord.rs b/src/client/discord.rs index 0591603..58a879f 100644 --- a/src/client/discord.rs +++ b/src/client/discord.rs @@ -1,19 +1,13 @@ +use crate::model::EvHandler; use anyhow::Context; use serenity::{prelude::GatewayIntents, Client}; -pub struct EvHandler; - -pub async fn start_discord_client(token: &str) -> anyhow::Result<()> { - // メッセージ内容の取得とギルドメッセージの取得を有効化 +pub async fn create_discord_client(token: &str) -> anyhow::Result<Client> { let intents = GatewayIntents::GUILD_MESSAGES | GatewayIntents::MESSAGE_CONTENT; - - let mut client = Client::builder(token, intents) + let client = Client::builder(token, intents) .event_handler(EvHandler) .await - .context("クライアントの作成に失敗しました.")?; + .context("Failed to create discord client")?; - client - .start() - .await - .context("クライアントの起動に失敗しました.") + Ok(client) } diff --git a/src/client/mod.rs b/src/client/mod.rs index 94cc9c5..3228a62 100644 --- a/src/client/mod.rs +++ b/src/client/mod.rs @@ -1,2 +1 @@ pub mod discord; -pub mod openai; diff --git a/src/client/openai.rs b/src/client/openai.rs deleted file mode 100644 index 7633c12..0000000 --- a/src/client/openai.rs +++ /dev/null @@ -1,151 +0,0 @@ -use crate::model::{MessageCompletionResult, ReplyMessage, ReplyRole}; -use anyhow::{Context, Ok}; -use async_openai::{ - types::{ChatCompletionRequestMessageArgs, CreateChatCompletionRequestArgs, Role}, - Client, -}; -use std::time::Duration; -use tokio::time::timeout; - -static TIMEOUT_DURATION: Duration = Duration::from_secs(180); - -// 会話モード・返信モード で使用するシステムコンテキスト。膨大なレスポンスにならないように抑える目的に使用する。 -// レスポンスの後にメタ情報(利用料金表示など)を含めるため、100字分の余裕を設けている。 -static SYSTEM_CONTEXT: &str = "回答時は以下のルールに従うこと.\n- 1900文字以内に収めること。\n- なるべく簡潔に言うこと。\n- 一般的に知られている単語は説明しない。"; - -/// ChatGPT に対してメッセージを送信し、レスポンスをリクエストします。 -/// -/// ### 引数 -/// * `request_message` -- ChatGPT に送信するメッセージ。[ReplyMessages] を実装しておく必要がある。 -/// * `model` -- -/// 使用する ChatGPT のモデルを使用する。使用できるモデルは 
[&str] で定義されている物のみ。 -/// 指定しない場合([None])は [&str::Gpt35Turbo] が使用される。 -/// ### 返り値 -/// [String]: ChatGPT からのレスポンス -/// -/// ### エラー -/// 下記条件でエラーが報告されます。 -/// * ChatGPT とのやり取りに失敗する -/// * 2000文字を超過する -pub async fn request_message( - request_message: &[ReplyMessage], - model: &str, -) -> anyhow::Result { - let client = Client::new(); - - let mut messages = vec![ChatCompletionRequestMessageArgs::default() - .role(Role::System) - .content(SYSTEM_CONTEXT) - .build()?]; - let history = request_message - .iter() - .map(|reply| { - ChatCompletionRequestMessageArgs::default() - .role(Role::User) - .content(reply.content.clone()) - .build() - }) - .collect::, _>>()?; - messages.extend(history); - - let request = CreateChatCompletionRequestArgs::default() - .model(model) - .messages(messages) - .build()?; - - let response = timeout(TIMEOUT_DURATION, client.chat().create(request)) - .await - .context("タイムアウトしました, もう一度お試しください.")??; - - let choice = response - .choices - .get(0) - .context("response message not found")?; - let (input_token, output_token, total_token) = response - .usage - .map(|usage| { - ( - usage.prompt_tokens, - usage.completion_tokens, - usage.total_tokens, - ) - }) - .unwrap_or_default(); - - let result = MessageCompletionResult { - message: choice.message.content.clone().unwrap_or_default(), - input_token, - output_token, - total_token, - }; - - Ok(result) -} - -/// ChatGPT に対して一連の会話コンテキストを送信し、レスポンスをリクエストします。 -/// -/// ### 引数 -/// * `reply_messages` -- ChatGPT に送信する会話コンテキスト。[ReplyMessages] を実装しておく必要がある。 -/// * `model` -- -/// 使用する ChatGPT のモデルを使用する。使用できるモデルは [&str] で定義されている物のみ。 -/// 指定しない場合([None])は [&str::Gpt35Turbo] が使用される。 -/// -/// ### 返り値 -/// [String]: ChatGPT からのレスポンス -pub async fn request_reply_message( - reply_messages: &[ReplyMessage], - model: &str, -) -> anyhow::Result { - let client = Client::new(); - - let mut messages = vec![ChatCompletionRequestMessageArgs::default() - .role(Role::System) - .content(SYSTEM_CONTEXT) - .build()?]; - let 
history = reply_messages - .iter() - .map(|reply| { - ChatCompletionRequestMessageArgs::default() - .role(match reply.role { - ReplyRole::Ichiyo => Role::Assistant, - ReplyRole::User => Role::User, - }) - .content(reply.content.clone()) - .build() - }) - .collect::, _>>()?; - messages.extend(history); - - let request = CreateChatCompletionRequestArgs::default() - .model(model) - .messages(messages) - .build()?; - - let response = timeout(TIMEOUT_DURATION, client.chat().create(request)) - .await - .context("タイムアウトしました, もう一度お試しください.")??; - - let choice = response - .choices - .get(0) - .context("response message not found")?; - let (input_token, output_token, total_token) = response - .usage - .map(|usage| { - ( - usage.prompt_tokens, - usage.completion_tokens, - usage.total_tokens, - ) - }) - .unwrap_or_default(); - - let result = MessageCompletionResult { - message: choice.message.content.clone().unwrap_or_default(), - input_token, - output_token, - total_token, - }; - - Ok(result) -} diff --git a/src/env.rs b/src/env.rs deleted file mode 100644 index 9fa1c01..0000000 --- a/src/env.rs +++ /dev/null @@ -1,13 +0,0 @@ -use dotenvy::dotenv; -use std::env; - -pub fn load_env() { - dotenv().ok(); -} - -pub fn get_env(key: &str) -> String { - match env::var(key) { - Ok(val) => val, - Err(e) => panic!("{}: {}", e, key), - } -} diff --git a/src/event.rs b/src/event.rs index b882001..0d55f79 100644 --- a/src/event.rs +++ b/src/event.rs @@ -1,98 +1,150 @@ -use crate::client::discord::EvHandler; -use crate::env::get_env; -use crate::service::chat::chat_mode; -use crate::service::reply::reply_mode; -use once_cell::sync::Lazy; +use crate::adapters::chatgpt::{request_chatgpt_message, SYSTEM_CONTEXT}; +use crate::adapters::discord::{format_result, reply_completion_result}; +use crate::model::chatgpt::RequestMessageModel; +use crate::model::discord::DiscordReplyMessageModel; +use crate::model::env::ICHIYOAI_ENV; +use crate::model::EvHandler; +use 
async_openai::types::{ChatCompletionRequestMessage, ChatCompletionRequestMessageArgs, Role}; +use once_cell::sync::OnceCell; use serenity::async_trait; use serenity::client::Context; use serenity::http::{Http, Typing}; -use serenity::model::channel::Message; +use serenity::model::channel::{Message, MessageType}; use serenity::model::gateway::Ready; -use serenity::model::id::ChannelId; -use serenity::model::prelude::{Activity, GuildId, MessageType, RoleId}; +use serenity::model::id::{ChannelId, GuildId, RoleId}; +use serenity::model::prelude::Activity; use serenity::prelude::EventHandler; use std::sync::Arc; use tracing::log::{error, info}; -static VERSION: &str = env!("CARGO_PKG_VERSION"); -static GUILD_ID: Lazy = Lazy::new(|| get_env("GUILD_ID").parse().unwrap()); -static SUBSCRIPTION_ROLE_ID: Lazy = - Lazy::new(|| get_env("SUBSCRIPTION_ROLE_ID").parse().unwrap()); - #[async_trait] impl EventHandler for EvHandler { - async fn message(&self, ctx: Context, new_msg: Message) { - if new_msg.author.bot || new_msg.is_private() { + async fn message(&self, ctx: Context, message: Message) { + if message.author.bot || message.is_private() { return; } - if let Ok(false) = new_msg.mentions_me(&ctx).await { + if let Ok(false) = message.mentions_me(&ctx).await { return; } - let http = ctx.clone().http; - let channel_id = new_msg.channel_id; - - info!("{sender}: 会話を開始します.", sender = new_msg.author.name); - let typing = start_typing(http, channel_id); - - let is_subscriber = new_msg - .author - .has_role(&ctx, GuildId(*GUILD_ID), RoleId(*SUBSCRIPTION_ROLE_ID)) - .await - .unwrap(); - let model = if is_subscriber { - "gpt-4" - } else { - "gpt-3.5-turbo" - }; - - match new_msg.kind { - // 通常メッセージ (チャットモード) - MessageType::Regular => { - if let Err(why) = chat_mode(&ctx, &new_msg, model).await { - let _ = new_msg - .reply_ping( - &ctx, - &format!("エラーが発生しました. 
\n```{error}\n```", error = why), - ) - .await; - error!("{:?}", why) - } - } - // 返信 (リプライモード) - MessageType::InlineReply => { - if let Err(why) = reply_mode(&ctx, &new_msg, model).await { - let _ = new_msg - .reply_ping( - &ctx, - &format!("エラーが発生しました.\n```{error}\n```", error = why), - ) - .await; - error!("{:?}", why) - } - } - _ => (), + match process_ichiyoai(ctx, message).await { + Ok(()) => (), + Err(why) => error!("Processing message failed with error: {}", why), } - - typing.stop(); - info!( - "{sender}: 会話を完了させました.", - sender = new_msg.author.name - ) } async fn ready(&self, ctx: Context, self_bot: Ready) { - ctx.set_activity(Activity::playing(&format!("v{}", VERSION))) + info!("Starting..."); + + let version = &ICHIYOAI_ENV.get().unwrap().cargo_pkg_version; + ctx.set_activity(Activity::playing(&format!("v{}", version))) .await; + info!("Running ichiyoAI v{}", version); info!( - "{username}(ID: {userid}) に接続しました! - ichiyoAI v{version} を使用しています.", - username = self_bot.user.name, - userid = self_bot.user.id, - version = VERSION + "Connected!: {name}(Id:{id})", + name = self_bot.user.name, + id = self_bot.user.id ) } } + +async fn process_ichiyoai(ctx: Context, message: Message) -> anyhow::Result<()> { + static OWN_MENTION: OnceCell<String> = OnceCell::new(); + let channel_id = message.channel_id; + let mention = OWN_MENTION.get_or_init(|| format!("<@{}>", ctx.cache.current_user_id())); + let mut content = message.content.replace(mention, "").trim().to_string(); + + if content.chars().count() < 5 { + return Err(anyhow::anyhow!( + "Message is too short. Please enter at least 5 characters." 
+ )); + } + + let typing = start_typing(ctx.http.clone(), channel_id); + let is_subscriber = message + .author + .has_role( + &ctx, + GuildId(ICHIYOAI_ENV.get().unwrap().guild_id), + RoleId(ICHIYOAI_ENV.get().unwrap().taxpayer_role_id), ) + .await + .unwrap_or(false); + let model = if is_subscriber { + "gpt-4".to_string() + } else { + "gpt-3.5-turbo".to_string() + }; + let mut replies: Vec<ChatCompletionRequestMessage> = + vec![ChatCompletionRequestMessageArgs::default() + .role(Role::System) + .content(SYSTEM_CONTEXT) + .build()?]; + + match message.kind { + MessageType::Regular => { + let reply = ChatCompletionRequestMessageArgs::default() + .role(Role::User) + .content(content) + .build()?; + replies.push(reply); + } + MessageType::InlineReply => { + let mut target_message_id = message.referenced_message.as_ref().map(|m| m.id); + while let Some(message_id) = target_message_id { + let message = ctx + .http + .clone() + .get_message(channel_id.0, message_id.0) + .await?; + + let role = if message.is_own(ctx.clone()) { + Role::Assistant + } else { + Role::User + }; + + if role == Role::Assistant { + let len = content.rfind("\n\n").unwrap_or(content.len()); + content.truncate(len); + } + + let content = message.content.replace(mention, "").trim().to_string(); + + let reply = ChatCompletionRequestMessageArgs::default() + .role(role) + .content(content) + .build()?; + + replies.push(reply); + target_message_id = message.referenced_message.map(|m| m.id); + } + } + _ => (), + } + + replies.reverse(); + + let request = RequestMessageModel::builder() + .replies(replies) + .model(model.clone()) + .build(); + let result = request_chatgpt_message(request).await?; + + let reply = DiscordReplyMessageModel::builder() + .http(ctx.http.clone()) + .target_message(message.clone()) + .formatted_result(format_result(result, &model)) + .build(); + + if let Err(why) = reply_completion_result(reply).await { + error!("{}", why) } + typing.stop(); + + Ok(()) } fn start_typing(http: Arc<Http>, target_channel_id: ChannelId) 
-> Typing { diff --git a/src/main.rs b/src/main.rs index 9515a4c..6a0a07e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,21 +1,27 @@ -use crate::env::load_env; -use client::discord::start_discord_client; -use env::get_env; +use crate::client::discord::create_discord_client; +use crate::model::env::{IchiyoAiEnv, ICHIYOAI_ENV}; +use dotenvy::dotenv; +use tracing::log::error; +mod adapters; mod client; -mod env; mod event; mod model; -mod service; #[tokio::main] async fn main() -> anyhow::Result<()> { - load_env(); + dotenv().ok(); tracing_subscriber::fmt().compact().init(); - let token = get_env("DISCORD_API_TOKEN"); + ICHIYOAI_ENV + .set(envy::from_env::<IchiyoAiEnv>().expect("Failed to load environment variables")) + .unwrap(); - start_discord_client(token.as_str()).await?; + let mut client = create_discord_client(&ICHIYOAI_ENV.get().unwrap().discord_api_token).await?; + + if let Err(why) = client.start().await { + error!("Failed to start ichiyoAI: {}", why) + } Ok(()) } diff --git a/src/model.rs b/src/model.rs deleted file mode 100644 index b6b27b3..0000000 --- a/src/model.rs +++ /dev/null @@ -1,19 +0,0 @@ -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ReplyRole { - Ichiyo, - User, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct ReplyMessage { - pub role: ReplyRole, - pub content: String, -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct MessageCompletionResult { - pub message: String, - pub input_token: u32, - pub output_token: u32, - pub total_token: u32, -} diff --git a/src/service/pricing.rs b/src/model/chatgpt.rs similarity index 66% rename from src/service/pricing.rs rename to src/model/chatgpt.rs index 1c818b9..90e6267 100644 --- a/src/service/pricing.rs +++ b/src/model/chatgpt.rs @@ -1,3 +1,20 @@ +use async_openai::types::ChatCompletionRequestMessage; +use typed_builder::TypedBuilder; + +#[derive(TypedBuilder)] +pub struct RequestMessageModel { + pub replies: Vec<ChatCompletionRequestMessage>, + pub model: String, +} + +#[derive(TypedBuilder)] +pub struct
ResponseCompletionResultModel { + pub response_message: String, + pub input_token: u32, + pub output_token: u32, + pub total_token: u32, +} + // 桁落ちを防ぐため、10,000,000倍して計算する const SCALE: f32 = 10_000_000.0; @@ -13,7 +30,7 @@ const GPT4_JPY_PER_OUTPUT_TOKEN: u32 = (0.06 * EXCHANGE_RATE * SCALE / 1000.0) a pub fn usage_pricing(input_token: u32, output_token: u32, model: &str) -> f32 { let (input_rate, output_rate) = match model { "gpt-3.5-turbo" => (GPT3_5_JPY_PER_INPUT_TOKEN, GPT3_5_JPY_PER_OUTPUT_TOKEN), - "gtp-4" => (GPT4_JPY_PER_INPUT_TOKEN, GPT4_JPY_PER_OUTPUT_TOKEN), + "gpt-4" => (GPT4_JPY_PER_INPUT_TOKEN, GPT4_JPY_PER_OUTPUT_TOKEN), _ => panic!("Invalid model: {:?}", model), }; diff --git a/src/model/discord.rs b/src/model/discord.rs new file mode 100644 index 0000000..87f36a9 --- /dev/null +++ b/src/model/discord.rs @@ -0,0 +1,18 @@ +use serenity::http::Http; +use serenity::model::prelude::{ChannelId, Message, MessageId}; +use std::sync::Arc; +use typed_builder::TypedBuilder; + +#[derive(TypedBuilder)] +pub struct DiscordMessageModel { + pub content: String, + pub message_channel_id: ChannelId, + pub reference_message_id: Option<MessageId>, +} + +#[derive(TypedBuilder)] +pub struct DiscordReplyMessageModel { + pub http: Arc<Http>, + pub target_message: Message, + pub formatted_result: String, +} diff --git a/src/model/env.rs b/src/model/env.rs new file mode 100644 index 0000000..35140f9 --- /dev/null +++ b/src/model/env.rs @@ -0,0 +1,13 @@ +use once_cell::sync::OnceCell; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +pub struct IchiyoAiEnv { + pub discord_api_token: String, + pub openai_api_key: String, + pub cargo_pkg_version: String, + pub guild_id: u64, + pub taxpayer_role_id: u64, +} + +pub static ICHIYOAI_ENV: OnceCell<IchiyoAiEnv> = OnceCell::new(); diff --git a/src/model/mod.rs b/src/model/mod.rs new file mode 100644 index 0000000..4d8bcc0 --- /dev/null +++ b/src/model/mod.rs @@ -0,0 +1,7 @@ +pub struct EvHandler; + +pub mod env; + +pub mod 
chatgpt; + +pub mod discord; diff --git a/src/service/chat.rs b/src/service/chat.rs deleted file mode 100644 index 205ce45..0000000 --- a/src/service/chat.rs +++ /dev/null @@ -1,45 +0,0 @@ -use crate::client::openai::request_message; -use crate::model::{MessageCompletionResult, ReplyMessage, ReplyRole}; -use anyhow::Ok; -use once_cell::sync::OnceCell; -use serenity::model::prelude::Message; -use serenity::prelude::Context; - -use super::pricing::usage_pricing; - -pub async fn chat_mode(ctx: &Context, msg: &Message, model: &str) -> anyhow::Result<()> { - let reply = get_reply(ctx, msg).await?; - let result = request_message(&reply, model).await?; - - msg.reply_ping(ctx, format_result(result, model)).await?; - - Ok(()) -} - -async fn get_reply(ctx: &Context, msg: &Message) -> anyhow::Result> { - static OWN_MENTION: OnceCell = OnceCell::new(); - let mention = OWN_MENTION.get_or_init(|| format!("<@{}>", ctx.cache.current_user_id())); - let content = msg.content.replace(mention, "").trim().to_string(); - - if content.chars().count() < 5 { - return Err(anyhow::anyhow!( - "送信メッセージが短すぎます。5文字以上入力してください。" - )); - } - - let mut replies: Vec = vec![ReplyMessage { - role: ReplyRole::User, - content, - }]; - - replies.reverse(); - Ok(replies) -} - -fn format_result(result: MessageCompletionResult, model: &str) -> String { - let pricing = usage_pricing(result.input_token, result.output_token, model); - format!( - "{}\n\n`利用料金: ¥{:.2}(合計トークン: {})` - `使用モデル: {}`", - result.message, pricing, result.total_token, model - ) -} diff --git a/src/service/mod.rs b/src/service/mod.rs deleted file mode 100644 index 2d1128f..0000000 --- a/src/service/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -pub mod chat; -pub mod pricing; -pub mod reply; diff --git a/src/service/reply.rs b/src/service/reply.rs deleted file mode 100644 index fafd9f5..0000000 --- a/src/service/reply.rs +++ /dev/null @@ -1,74 +0,0 @@ -use crate::client::openai::request_reply_message; -use crate::model::{MessageCompletionResult, 
ReplyMessage, ReplyRole}; -use once_cell::sync::OnceCell; -use serenity::model::prelude::Message; -use serenity::prelude::Context; - -use super::pricing::usage_pricing; - -pub async fn reply_mode(ctx: &Context, msg: &Message, model: &str) -> anyhow::Result<()> { - let replies = get_replies(ctx, msg).await?; - // notes: GPT-4 があまりにも高いため、GPT-3.5 に revert - let result = request_reply_message(&replies, model).await?; - - msg.reply_ping(ctx, format_result(result, model)).await?; - - Ok(()) -} - -async fn get_replies(ctx: &Context, msg: &Message) -> anyhow::Result> { - static OWN_MENTION: OnceCell = OnceCell::new(); - let mention = OWN_MENTION.get_or_init(|| format!("<@{}>", ctx.cache.current_user_id())); - - let mut replies: Vec = vec![ReplyMessage { - role: ReplyRole::User, - content: msg.content.clone(), - }]; - - // TODO: イテレータにしたい - let channel_id = msg.channel_id; - let mut target_message_id = msg.referenced_message.as_ref().map(|m| m.id); - while let Some(message_id) = target_message_id { - // `.referenced_message`は直近のメッセージしかSomeでは無いため,`.get_message`でメッセージを取得している. 
- let message = ctx.http.get_message(channel_id.0, message_id.0).await?; - - let role = if message.is_own(ctx) { - ReplyRole::Ichiyo - } else { - ReplyRole::User - }; - - let mut content = msg.content.replace(mention, "").trim().to_string(); - - if content.chars().count() < 5 { - return Err(anyhow::anyhow!( - "送信メッセージが短すぎます。5文字以上入力してください。" - )); - } - - // 一葉のメッセージの場合、最後の値段表示を削除する - if role == ReplyRole::Ichiyo { - let len = content.rfind("\n\n").unwrap_or(content.len()); - content.truncate(len); - } - - let reply = ReplyMessage { role, content }; - - replies.push(reply); - - target_message_id = message.referenced_message.map(|m| m.id); - } - - replies.reverse(); - Ok(replies) -} - -// chatにあるものと同一だが、変更の可能性が高いためあえて共通化しない -// TODO (MikuroXina): 同じ概念をフォーマットするものであるため、共通化すべき -fn format_result(result: MessageCompletionResult, model: &str) -> String { - let pricing = usage_pricing(result.input_token, result.output_token, model); - format!( - "{}\n\n`累計利用料金: ¥{:.2}(合計トークン: {})` - `使用モデル: {}`", - result.message, pricing, result.total_token, model - ) -}