canary-rs/apps/magpie/src/protocol.rs

160 lines
4.9 KiB
Rust

// Copyright (c) 2022 Marceline Cramer
// SPDX-License-Identifier: AGPL-3.0-or-later
use std::collections::VecDeque;
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::path::PathBuf;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// The name of the Magpie server socket.
///
/// This is only a bare file name; nothing in this module decides which
/// directory it is resolved against.
pub const MAGPIE_SOCK: &str = "magpie.sock";
/// An identifier for a Magpie panel.
///
/// Only valid on a connection between a single client and its server. Clients
/// are allowed to use arbitrary values for [PanelId]; the same ID is carried
/// by [CreatePanel], [SendMessage], and [RecvMessage] to name the panel a
/// message refers to.
pub type PanelId = u32;
/// Creates a new Magpie panel with a given ID.
///
/// If the given [PanelId] is already being used on this connection, the server
/// will delete the old panel using that [PanelId].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct CreatePanel {
/// The client-chosen ID that later messages use to refer to this panel.
pub id: PanelId,
/// Name of the protocol for the new panel.
// NOTE(review): presumably identifies how the panel's script and the client
// interpret message payloads; the accepted values are not defined in this file.
pub protocol: String,
/// Path of the script for the new panel.
// NOTE(review): presumably resolved on the server side; confirm against the
// server implementation.
pub script: PathBuf,
}
/// Sends a panel a message.
///
/// Travels from client to server as a [MagpieServerMsg::SendMessage].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SendMessage {
/// The panel the message is addressed to.
pub id: PanelId,
/// The raw message payload; this module does not interpret its contents.
pub msg: Vec<u8>,
}
/// A message sent from a Magpie client to the server.
///
/// Serialized with serde's internally tagged representation: each encoded
/// message carries a `"kind"` field naming the variant alongside the fields of
/// the variant's payload struct.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "kind")]
pub enum MagpieServerMsg {
/// Request to create (or replace) a panel; see [CreatePanel].
CreatePanel(CreatePanel),
/// Request to deliver a payload to a panel; see [SendMessage].
SendMessage(SendMessage),
}
/// A message sent from a script's panel to a client.
///
/// Travels from server to client as a [MagpieClientMsg::RecvMessage].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct RecvMessage {
/// The panel the message originated from.
pub id: PanelId,
/// The raw message payload; this module does not interpret its contents.
pub msg: Vec<u8>,
}
/// A message sent from the Magpie server to a client.
///
/// Serialized with serde's internally tagged representation: each encoded
/// message carries a `"kind"` field naming the variant alongside the fields of
/// the variant's payload struct.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "kind")]
pub enum MagpieClientMsg {
/// A payload emitted by one of the client's panels; see [RecvMessage].
RecvMessage(RecvMessage),
}
/// A [Messenger] specialized for Magpie clients.
///
/// Receives [MagpieClientMsg] values and sends [MagpieServerMsg] values over
/// the transport `T`.
pub type ClientMessenger<T> = Messenger<T, MagpieClientMsg, MagpieServerMsg>;
/// A [Messenger] specialized for Magpie servers.
///
/// Receives [MagpieServerMsg] values and sends [MagpieClientMsg] values over
/// the transport `T`.
pub type ServerMessenger<T> = Messenger<T, MagpieServerMsg, MagpieClientMsg>;
/// Bidirectional, transport-agnostic Magpie IO wrapper struct.
///
/// `T` is the underlying byte transport, `I` is the type of incoming messages,
/// and `O` is the type of outgoing messages. Messages travel as frames made of
/// a 4-byte little-endian length prefix followed by a JSON payload.
pub struct Messenger<T, I, O> {
/// The wrapped transport. Public so that callers can reach the underlying
/// object directly.
pub transport: T,
// Payload length of the frame currently being received, once its length
// prefix has been read; `None` while waiting for the next prefix.
expected_len: Option<usize>,
// Bytes read from the transport that have not been decoded yet.
received_buf: VecDeque<u8>,
// Decoded incoming messages that have not been handed to the caller yet.
received_queue: VecDeque<I>,
// Set once the transport reports end-of-stream or a connection reset.
closed: bool,
// Ties the outgoing message type `O` to the struct without storing a value.
_output: PhantomData<O>,
}
impl<T: Read + Write, I: DeserializeOwned, O: Serialize> Messenger<T, I, O> {
    /// Wraps `transport` in a new messenger with empty receive state.
    pub fn new(transport: T) -> Self {
        Self {
            transport,
            expected_len: None,
            received_buf: Default::default(),
            received_queue: Default::default(),
            closed: false,
            _output: PhantomData,
        }
    }

    /// Returns `true` once [Self::flush_recv] has observed end-of-stream or a
    /// connection reset on the transport.
    pub fn is_closed(&self) -> bool {
        self.closed
    }

    /// Serializes `msg` as JSON and writes it to the transport as one frame: a
    /// 4-byte little-endian payload length followed by the payload itself.
    ///
    /// # Errors
    ///
    /// Returns [std::io::ErrorKind::InvalidData] if `msg` cannot be
    /// serialized, [std::io::ErrorKind::InvalidInput] if the serialized
    /// payload is too large for the `u32` length prefix, and otherwise any
    /// error reported by the transport while writing or flushing.
    pub fn send(&mut self, msg: &O) -> std::io::Result<()> {
        use std::convert::TryFrom;

        let prefix_len = std::mem::size_of::<u32>();

        // Reserve space for the length prefix and serialize right behind it,
        // so the whole frame reaches the transport in a single write.
        let mut frame = vec![0u8; prefix_len];
        serde_json::to_writer(&mut frame, msg)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;

        // A truncated length would desynchronize the peer's framing, so an
        // oversized payload is rejected rather than cast down.
        let payload_len = u32::try_from(frame.len() - prefix_len).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "serialized message does not fit in a u32 length prefix",
            )
        })?;
        frame[..prefix_len].copy_from_slice(&payload_len.to_le_bytes());

        self.transport.write_all(&frame)?;
        self.transport.flush()
    }

    /// Receives all pending messages and queues them for [Self::recv].
    ///
    /// Reads from the transport until it reports end-of-stream, a connection
    /// reset, or `WouldBlock`, then decodes every complete frame that has been
    /// buffered so far. A partially received frame stays buffered until a
    /// later call completes it.
    ///
    /// # Errors
    ///
    /// Returns any other transport read error as-is, or
    /// [std::io::ErrorKind::InvalidData] if a frame's payload cannot be
    /// deserialized. The offending frame is discarded; frames buffered after
    /// it are decoded by the next call.
    pub fn flush_recv(&mut self) -> std::io::Result<()> {
        self.read_available()?;
        self.decode_buffered()
    }

    /// Tries to receive a single input packet.
    ///
    /// Returns the oldest message queued by [Self::flush_recv], or `None` if
    /// the queue is empty.
    pub fn recv(&mut self) -> Option<I> {
        self.received_queue.pop_front()
    }

    /// Appends every byte the transport can currently supply to the buffer.
    fn read_available(&mut self) -> std::io::Result<()> {
        use std::io::ErrorKind;

        let mut chunk = [0u8; 1024];
        loop {
            match self.transport.read(&mut chunk) {
                // A zero-length read signals end-of-stream.
                Ok(0) => {
                    self.closed = true;
                    return Ok(());
                }
                Ok(n) => self.received_buf.extend(&chunk[..n]),
                Err(err) => match err.kind() {
                    // A reset is treated like a regular close, not a failure.
                    ErrorKind::ConnectionReset => {
                        self.closed = true;
                        return Ok(());
                    }
                    // Nothing more to read right now.
                    ErrorKind::WouldBlock => return Ok(()),
                    ErrorKind::Interrupted => continue,
                    _ => return Err(err),
                },
            }
        }
    }

    /// Decodes every complete frame in the buffer into the message queue.
    fn decode_buffered(&mut self) -> std::io::Result<()> {
        loop {
            match self.expected_len {
                Some(len) => {
                    if self.received_buf.len() < len {
                        return Ok(());
                    }
                    // Clear the pending length first so that a bad payload
                    // does not wedge the decoder on the same frame.
                    self.expected_len = None;
                    let payload: Vec<u8> = self.received_buf.drain(..len).collect();
                    let msg = serde_json::from_slice::<I>(&payload)
                        .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
                    self.received_queue.push_back(msg);
                }
                None => {
                    let mut prefix = [0u8; 4];
                    if self.received_buf.len() < prefix.len() {
                        return Ok(());
                    }
                    self.received_buf.read_exact(&mut prefix)?;
                    self.expected_len = Some(u32::from_le_bytes(prefix) as usize);
                }
            }
        }
    }
}