Skip to content

Commit

Permalink
Implement command responses (WIP)
Browse files Browse the repository at this point in the history
It freezes
  • Loading branch information
lptr committed Aug 10, 2024
1 parent 9c2d820 commit b284644
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 28 deletions.
45 changes: 23 additions & 22 deletions src/kernel/command.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
use serde::de::DeserializeOwned;
use anyhow::{anyhow, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::{collections::HashMap, sync::Arc};

trait CommandHandler {
fn execute(&self, payload: &str);
fn execute(&self, payload: &str) -> Result<Option<String>>;
}

struct Command<T> {
delegate: Arc<dyn Fn(T) + Send + Sync>,
struct Command<I, O> {
delegate: Arc<dyn Fn(I) -> Result<Option<O>> + Send + Sync>,
}

impl<T> CommandHandler for Command<T>
impl<I, O> CommandHandler for Command<I, O>
where
T: DeserializeOwned + 'static,
I: DeserializeOwned + 'static,
O: Serialize + 'static,
{
fn execute(&self, payload: &str) {
match serde_json::from_str::<T>(payload) {
Ok(data) => {
(self.delegate)(data);
}
Err(e) => {
eprintln!("Failed to deserialize payload: {}", e);
}
}
fn execute(&self, payload: &str) -> Result<Option<String>> {
let data = serde_json::from_str(payload)?;
let result = (self.delegate)(data)?;
Ok(result
.map(|result| serde_json::to_string(&result))
.transpose()?)
}
}

Expand All @@ -36,10 +35,11 @@ impl CommandManager {
}
}

pub fn register<T, F>(&mut self, path: &str, command: F)
pub fn register<I, O, F>(&mut self, path: &str, command: F)
where
T: DeserializeOwned + 'static,
F: Fn(T) + Send + Sync + 'static,
I: DeserializeOwned + 'static,
O: Serialize + 'static,
F: Fn(I) -> Result<Option<O>> + Send + Sync + 'static,
{
self.commands.insert(
path.to_string(),
Expand All @@ -49,11 +49,12 @@ impl CommandManager {
);
}

pub fn handle(&self, command: &str, payload: &str) {
if let Some(command) = self.commands.get(command) {
command.execute(payload);
pub fn handle(&self, command_name: &str, payload: &str) -> Result<Option<String>> {
if let Some(command) = self.commands.get(command_name) {
let response = command.execute(payload)?;
Ok(response)
} else {
eprintln!("Unregistered registered: {}", command);
Err(anyhow!("Unregistered registered: {}", command_name))
}
}
}
22 changes: 20 additions & 2 deletions src/kernel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod mqtt;
mod rtc;
mod wifi;

use anyhow::anyhow;
use anyhow::Result;
use embassy_futures::join::join;
use embassy_sync::blocking_mutex::raw::CriticalSectionRawMutex;
Expand All @@ -16,8 +17,10 @@ use esp_idf_svc::wifi::AsyncWifi;
use esp_idf_svc::wifi::EspWifi;
use esp_idf_svc::{eventloop::EspSystemEventLoop, nvs::EspDefaultNvsPartition};
use esp_idf_sys::{esp_pm_config_esp32_t, esp_pm_configure};
use mqtt::IncomingMessageResponse;
use mqtt::Mqtt;
use serde::{Deserialize, Serialize};
use serde_json::json;
use serde_json::Value;
use static_cell::StaticCell;
use std::ffi::c_void;
Expand Down Expand Up @@ -70,6 +73,7 @@ impl Device {
let command_manager = COMMAND_MANAGER.init(command::CommandManager::new());
command_manager.register("ping", |v: Value| {
log::info!("Ping received: {:?}", v);
Ok(Some(json!({"pong": v})))
});

// TODO Use some async mDNS instead to avoid blocking the executor
Expand All @@ -79,9 +83,23 @@ impl Device {
mqtt::Mqtt::create(
&mdns,
&config.instance,
Box::new(|path, payload| {
Box::new(|path, payload| -> Result<IncomingMessageResponse> {
if let Some(command) = path.strip_prefix("commands/") {
command_manager.handle(command, payload);
let result = command_manager.handle(command, payload);
match result {
Ok(Some(response)) => {
log::info!("Command response: {}", response);
Ok(Some((format!("responses/{command}"), response)))
}
Ok(None) => {
log::info!("Command response: None");
Ok(None)
}
Err(e) => Err(anyhow!("Command error: {}", e)),
}
} else {
log::info!("Not a command: {}", path);
Ok(None)
}
}),
),
Expand Down
32 changes: 28 additions & 4 deletions src/kernel/mqtt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ pub struct Mqtt {
_conn: Arc<Mutex<CriticalSectionRawMutex, EspAsyncMqttConnection>>,
}

type IncomingMessageHandler = Box<dyn Fn(&str, &str)>;
pub type IncomingMessageResponse = Option<(String, String)>;
type IncomingMessageHandler = Box<dyn Fn(&str, &str) -> Result<IncomingMessageResponse>>;

impl Mqtt {
pub async fn create(
Expand Down Expand Up @@ -57,6 +58,7 @@ impl Mqtt {
Spawner::for_current_executor()
.await
.spawn(handle_mqtt_events(
mqtt.clone(),
conn.clone(),
format!("{topic_root}/"),
handler,
Expand Down Expand Up @@ -101,8 +103,9 @@ impl Mqtt {

#[embassy_executor::task]
async fn handle_mqtt_events(
mqtt: Arc<Mutex<CriticalSectionRawMutex, EspAsyncMqttClient>>,
conn: Arc<Mutex<CriticalSectionRawMutex, EspAsyncMqttConnection>>,
incoming_prefix: String,
prefix: String,
handler: IncomingMessageHandler,
connected: Arc<Signal<CriticalSectionRawMutex, ()>>,
) {
Expand All @@ -126,9 +129,30 @@ async fn handle_mqtt_events(
let data = std::str::from_utf8(data);
if let Ok(data) = data {
if let Some(path) = topic {
if let Some(path) = path.strip_prefix(&incoming_prefix) {
if let Some(path) = path.strip_prefix(&prefix) {
log::info!("Received message for path: {:?}", path);
(handler)(path, data);
let result = (handler)(path, data);
match result {
Ok(Some((response_path, response))) => {
let response_topic = format!("{}{}", prefix, response_path);
log::info!("Publishing response to: {:?}", response_topic);
let _ = mqtt
.lock()
.await
.publish(
&response_topic,
QoS::AtMostOnce,
false,
response.as_bytes(),
)
.await;
}
Ok(None) => {}
Err(e) => {
log::error!("Error handling message: {:?}", e);
// TODO Publish error
}
}
continue;
}
}
Expand Down

0 comments on commit b284644

Please sign in to comment.