llm-code-example

---concurrency-limiter---

---concurrency-limiter/Cargo.mdx---

concurrency-limiter/Cargo.toml
[package]
name = "example-concurrency-limiter"
version = "0.1.0"
edition = "2024"


[dependencies]
salvo = { version = "0.77.1", features = ["concurrency-limiter"]}
tokio = { version = "1", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = "0.3"

---concurrency-limiter/src---

---concurrency-limiter/src/main.mdx---

concurrency-limiter/src/main.rs
use std::fs::create_dir_all;
use std::path::Path;

use salvo::prelude::*;

// Serves the landing page, which holds one upload form per route
// (`/unlimit` and `/limited`).
#[handler]
async fn index(res: &mut Response) {
    let page = Text::Html(INDEX_HTML);
    res.render(page);
}

// Handler for processing file uploads with a simulated delay.
//
// Reads the multipart field `file`, then waits 10 seconds so that the concurrency
// limiter on `/limited` has long-running requests to limit. After the wait it copies
// the upload into the `temp` directory.
// Responds 400 when the field is missing and 500 when the copy fails.
#[handler]
async fn upload(req: &mut Request, res: &mut Response) {
    // Extract the file from the multipart form data before the simulated work.
    let file = req.file("file").await;
    // Simulate a long-running operation (10 seconds).
    tokio::time::sleep(tokio::time::Duration::from_secs(10)).await;

    let Some(file) = file else {
        res.status_code(StatusCode::BAD_REQUEST);
        res.render(Text::Plain("file not found in request"));
        return;
    };

    // The client controls the uploaded file name. Keep only its final path component so
    // that names such as `../../etc/passwd` cannot escape `temp`. Names without a usable
    // final component (missing, empty, `..`, or non-UTF-8) fall back to "file".
    let file_name = file
        .name()
        .and_then(|name| Path::new(name).file_name())
        .and_then(|name| name.to_str())
        .unwrap_or("file");
    let dest = Path::new("temp").join(file_name);
    tracing::debug!(dest = %dest.display(), "upload file");

    // `std::fs::copy` is a blocking call; acceptable for this demo.
    match std::fs::copy(file.path(), &dest) {
        Ok(_) => res.render(Text::Plain(format!("File uploaded to {}", dest.display()))),
        Err(e) => {
            res.status_code(StatusCode::INTERNAL_SERVER_ERROR);
            res.render(Text::Plain(format!("failed to save uploaded file: {e}")));
        }
    }
}

#[tokio::main]
async fn main() {
    // Initialize logging
    tracing_subscriber::fmt().init();

    // Uploads are copied into `temp`, so the directory must exist before serving.
    create_dir_all("temp").unwrap();

    // `/limited`: the `max_concurrency(1)` hoop limits concurrent uploads to one.
    let limited = Router::with_path("limited")
        .hoop(max_concurrency(1))
        .post(upload);
    // `/unlimit`: the same handler without any concurrency cap.
    let unlimited = Router::with_path("unlimit").post(upload);

    let router = Router::new().get(index).push(limited).push(unlimited);

    // Listen on every interface, port 5800.
    let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
    Server::new(acceptor).serve(router).await;
}

// Static HTML for the index page. It has two multipart upload forms, both sending the
// file in the field `file`: one posts to `/unlimit` and the other to `/limited`.
static INDEX_HTML: &str = r#"<!DOCTYPE html>
<html>
    <head>
        <title>Upload file</title>
    </head>
    <body>
        <h1>Upload file</h1>
        <form action="/unlimit" method="post" enctype="multipart/form-data">
            <h3>Unlimit</h3>
            <input type="file" name="file" />
            <input type="submit" value="upload" />
        </form>
        <form action="/limited" method="post" enctype="multipart/form-data">
            <h3>Limited</h3>
            <input type="file" name="file" />
            <input type="submit" value="upload" />
        </form>
    </body>
</html>
"#;

---csrf-session-store---

---csrf-session-store/Cargo.mdx---

csrf-session-store/Cargo.toml
[package]
name = "example-csrf-session-store"
version = "0.1.0"
edition = "2024"

[dependencies]
salvo = { version = "0.77.1", features = ["csrf", "session"] }
tokio = { version = "1", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = "0.3"
serde = { version = "1", features = ["derive"] }

---csrf-session-store/src---

---csrf-session-store/src/main.mdx---

csrf-session-store/src/main.rs
use salvo::csrf::*;
use salvo::prelude::*;
use serde::{Deserialize, Serialize};

// Handler for the landing page. It links to the four CSRF-protected demo routes, one
// per token-protection strategy (bcrypt, HMAC, AES-GCM, ChaCha20-Poly1305).
#[handler]
pub async fn home(res: &mut Response) {
    let html = r#"
    <!DOCTYPE html>
    <html>
    <head><meta charset="UTF-8"><title>Csrf SessionStore</title></head>
    <body>
    <h2>Csrf Example: SessionStore</h2>
    <ul>
        <li><a href="/bcrypt/">Bcrypt</a></li>
        <li><a href="/hmac/">Hmac</a></li>
        <li><a href="/aes_gcm/">Aes Gcm</a></li>
        <li><a href="/ccp/">chacha20poly1305</a></li>
    </ul>
    </body>
    </html>"#;
    res.render(Text::Html(html));
}

// GET handler for every protected route. It renders the message form with the CSRF
// token that is currently in the depot. A missing token falls back to the default
// value via `unwrap_or_default`.
#[handler]
pub async fn get_page(depot: &mut Depot, res: &mut Response) {
    let token = depot.csrf_token().unwrap_or_default();
    let page = get_page_html(token, "");
    res.render(Text::Html(page));
}

// Handler for POST requests that processes form submission with CSRF validation
#[handler]
pub async fn post_page(req: &mut Request, depot: &mut Depot, res: &mut Response) {
    // Define data structure for form submission
    #[derive(Deserialize, Serialize, Debug)]
    struct Data {
        csrf_token: String,
        message: String,
    }
    // Parse the submitted form data into the Data struct
    let data = req.parse_form::<Data>().await.unwrap();
    // Log the received form data for debugging
    tracing::info!("posted data: {:?}", data);
    // Generate a new CSRF token for the next request
    let new_token = depot.csrf_token().unwrap_or_default();
    // Generate HTML response with the new token and display the submitted data
    let html = get_page_html(new_token, &format!("{data:#?}"));
    // Send the HTML response back to the client
    res.render(Text::Html(html));
}

#[tokio::main]
async fn main() {
    // Initialize logging system
    tracing_subscriber::fmt().init();

    // Configure CSRF token finder in form data
    let form_finder = FormFinder::new("csrf_token");

    // Initialize different CSRF protection methods using session store
    let bcrypt_csrf = bcrypt_session_csrf(form_finder.clone());
    let hmac_csrf = hmac_session_csrf(*b"01234567012345670123456701234567", form_finder.clone());
    let aes_gcm_session_csrf =
        aes_gcm_session_csrf(*b"01234567012345670123456701234567", form_finder.clone());
    let ccp_session_csrf =
        ccp_session_csrf(*b"01234567012345670123456701234567", form_finder.clone());

    // Configure session handler with memory store and secret key
    let session_handler = salvo::session::SessionHandler::builder(
        salvo::session::MemoryStore::new(),
        b"secretabsecretabsecretabsecretabsecretabsecretabsecretabsecretab",
    )
    .build()
    .unwrap();

    // Configure router with session handler and different CSRF protection endpoints
    let router = Router::new()
        .get(home)
        .hoop(session_handler)
        // Bcrypt-based CSRF protection
        .push(
            Router::with_hoop(bcrypt_csrf)
                .path("bcrypt")
                .get(get_page)
                .post(post_page),
        )
        // HMAC-based CSRF protection
        .push(
            Router::with_hoop(hmac_csrf)
                .path("hmac")
                .get(get_page)
                .post(post_page),
        )
        // AES-GCM-based CSRF protection
        .push(
            Router::with_hoop(aes_gcm_session_csrf)
                .path("aes_gcm")
                .get(get_page)
                .post(post_page),
        )
        // ChaCha20Poly1305-based CSRF protection
        .push(
            Router::with_hoop(ccp_session_csrf)
                .path("ccp")
                .get(get_page)
                .post(post_page),
        );

    // Start server on port 5800
    let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
    Server::new(acceptor).serve(router).await;
}

// Builds the demo page: navigation links, a form carrying `csrf_token` in a hidden
// field, and `msg` shown in a `<pre>` block.
//
// Both values are HTML-escaped before interpolation. `msg` contains user-submitted
// text, so inserting it raw would allow reflected XSS. Escaping the token keeps the
// attribute well-formed, and the browser unescapes it, so the submitted value is
// unchanged.
fn get_page_html(csrf_token: &str, msg: &str) -> String {
    let csrf_token = escape_html(csrf_token);
    let msg = escape_html(msg);
    format!(
        r#"
    <!DOCTYPE html>
    <html>
    <head><meta charset="UTF-8"><title>Csrf SessionStore</title></head>
    <body>
    <h2>Csrf Example: SessionStore</h2>
    <ul>
        <li><a href="/bcrypt/">Bcrypt</a></li>
        <li><a href="/hmac/">Hmac</a></li>
        <li><a href="/aes_gcm/">Aes Gcm</a></li>
        <li><a href="/ccp/">chacha20poly1305</a></li>
    </ul>
    <form action="./" method="post">
        <input type="hidden" name="csrf_token" value="{csrf_token}" />
        <div>
            <label>Message:<input type="text" name="message" /></label>
        </div>
        <button type="submit">Send</button>
    </form>
    <pre>{msg}</pre>
    </body>
    </html>
    "#
    )
}

// Escapes the characters that are significant in HTML text and quoted attributes.
fn escape_html(raw: &str) -> String {
    let mut out = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#x27;"),
            _ => out.push(ch),
        }
    }
    out
}

---webtransport---

---webtransport/certs---

---webtransport/static---

---webtransport/static/client.mdx---

webtransport/static/client.html
<!doctype html>
<html lang="en">
  <title>WebTransport over HTTP/3 client</title>
  <meta charset="utf-8">
  <!-- WebTransport origin trial token. See https://developer.chrome.com/origintrials/#/view_trial/793759434324049921 -->
  <meta http-equiv="origin-trial" content="AkSQvBVsfMTgBtlakApX94hWGyBPQJXerRc2Aq8g/sKTMF+yG62+bFUB2yIxaK1furrNH3KNNeJV00UZSZHicw4AAABceyJvcmlnaW4iOiJodHRwczovL2dvb2dsZWNocm9tZS5naXRodWIuaW86NDQzIiwiZmVhdHVyZSI6IldlYlRyYW5zcG9ydCIsImV4cGlyeSI6MTY0Mzc1OTk5OX0=">
  <script src="client.js"></script>
  <link rel="stylesheet" href="client.css">
  <meta name="viewport" content="width=device-width, initial-scale=1">
  <body>
  <div id="top">
    <div id="explanation">
      This tool can be used to connect to an arbitrary WebTransport server.
      It has several limitations:
      <ul>
        <li>It can only send an entire stream at once.  Once the stream
          is opened, all of the data is immediately sent, and the write side of
          the stream is closed.</li>
        <li>This tool does not listen to server-initiated bidirectional
          streams.</li>
        <li>Stream IDs are different from the ones used by QUIC on the wire, as
          the on-the-wire IDs are not exposed via the Web API.</li>
        <li>The <code>WebTransport</code> object can be accessed using the developer console via <code>currentTransport</code>.</li>
      </ul>
    </div>
    <div id="tool">
    <h1>WebTransport over HTTP/3 client</h1>
    <div>
      <h2>Establish WebTransport connection</h2>
      <div class="input-line">
      <label for="url">URL:</label>
      <input type="text" name="url" id="url"
             value="https://0.0.0.0:5800/counter">
      <input type="button" id="connect" value="Connect" onclick="connect()">
      </div>
    </div>
    <div>
      <h2>Send data over WebTransport</h2>
      <form name="sending">
      <textarea name="data" id="data"></textarea>
      <div>
        <input type="radio" name="sendtype" value="datagram"
               id="datagram" checked>
        <label for="datagram">Send a datagram</label>
      </div>
      <div>
        <input type="radio" name="sendtype" value="unidi" id="unidi-stream">
        <label for="unidi-stream">Open a unidirectional stream</label>
      </div>
      <div>
        <input type="radio" name="sendtype" value="bidi" id="bidi-stream">
        <label for="bidi-stream">Open a bidirectional stream</label>
      </div>
      <input type="button" id="send" name="send" value="Send data"
             disabled onclick="sendData()">
      </form>
    </div>
    <div>
      <h2>Event log</h2>
      <ul id="event-log">
      </ul>
    </div>
    </div>
  </div>
  </body>
</html>

---webtransport/Cargo.mdx---

webtransport/Cargo.toml
[package]
name = "example-webtransport"
version = "0.1.0"
edition = "2024"


[dependencies]
anyhow = "1"
futures-util = "0.3"
salvo = { version = "0.77.1", features = ["quinn", "anyhow", "serve-static"] }
tokio = { version = "1", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = "0.3"
serde = "1"
serde_json = "1"
bytes = "1"

---webtransport/src---

---webtransport/src/main.mdx---

webtransport/src/main.rs
use std::time::Duration;

use anyhow::{Context, Result};
use bytes::{BufMut, Bytes, BytesMut};
use salvo::conn::rustls::{Keycert, RustlsConfig};
use salvo::prelude::*;
use salvo::proto::webtransport;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::pin;

// Evaluates `$expr` (typically an awaited `Result`). If it is an `Err`, logs the error's
// `Debug` representation at `error` level. The error is not propagated. It is used
// inside spawned tasks, whose errors would otherwise be silently dropped.
macro_rules! log_result {
    ($expr:expr) => {
        if let Err(err) = $expr {
            tracing::error!("{err:?}");
        }
    };
}
// Reads `recv` until EOF and writes everything it received back to `send`, in chunks,
// via `send_chunked`. The halves need not be `Unpin`, so they are pinned on the stack.
async fn echo_stream<T, R>(send: T, recv: R) -> anyhow::Result<()>
where
    T: AsyncWrite,
    R: AsyncRead,
{
    pin!(send);
    pin!(recv);

    tracing::info!("Got stream");

    let mut received = Vec::new();
    recv.read_to_end(&mut received).await?;

    send_chunked(send, Bytes::from(received)).await
}
// Writes `data` to `send` in 4-byte pieces, sleeping 100 ms before each write.
// The artificial chunking is deliberate: it checks that every chunk arrives, since it
// is easy to write a peer that only reads and writes the first chunk.
async fn send_chunked(mut send: impl AsyncWrite + Unpin, data: Bytes) -> anyhow::Result<()> {
    const CHUNK_LEN: usize = 4;

    for start in (0..data.len()).step_by(CHUNK_LEN) {
        let end = (start + CHUNK_LEN).min(data.len());
        let chunk = &data[start..end];

        tokio::time::sleep(Duration::from_millis(100)).await;
        tracing::info!("Sending {chunk:?}");
        send.write_all(chunk).await?;
    }

    Ok(())
}

// WebTransport session handler (mounted at `/counter` in `main`).
//
// It first opens one server-initiated bidirectional stream (handled by
// `open_bidi_test`). It then echoes client traffic until the session stops
// delivering:
// - datagrams are answered with a datagram prefixed by "Response: ";
// - each client unidirectional stream is echoed on a new server unidirectional stream;
// - each client bidirectional stream is echoed back on the same stream.
//
// Errors: returns an error instead of panicking when the request is not a WebTransport
// session, and propagates session/stream errors.
#[handler]
async fn connect(req: &mut Request) -> Result<(), salvo::Error> {
    // A plain HTTP request to this route is an error response, not a panic.
    let session = req.web_transport_mut().await?;
    let session_id = session.session_id();

    // This will open a bidirectional stream and send a message to the client right after connecting!
    let stream = session.open_bi(session_id).await?;

    tokio::spawn(async move {
        log_result!(open_bidi_test(stream).await);
    });

    // Each `accept_*` future yields `Ok(None)` once it has nothing more to deliver
    // (presumably because the session closed). Leave the loop at that point. Every
    // branch pattern is irrefutable, so a `select!` `else` arm would never run; without
    // this, the loop would spin forever or panic.
    loop {
        tokio::select! {
            datagram = session.accept_datagram() => {
                let Some((_, datagram)) = datagram? else {
                    break;
                };
                tracing::info!("Responding with {datagram:?}");
                // Put something before to make sure encoding and decoding works and don't just
                // pass through
                let mut resp = BytesMut::from(&b"Response: "[..]);
                resp.put(datagram);

                session.send_datagram(resp.freeze())?;
                tracing::info!("Finished sending datagram");
            }
            uni_stream = session.accept_uni() => {
                let Some((id, stream)) = uni_stream? else {
                    break;
                };

                let send = session.open_uni(id).await?;
                tokio::spawn(async move { log_result!(echo_stream(send, stream).await); });
            }
            stream = session.accept_bi() => {
                match stream? {
                    Some(webtransport::server::AcceptedBi::BidiStream(_, stream)) => {
                        let (send, recv) = salvo::proto::quic::BidiStream::split(stream);
                        tokio::spawn(async move { log_result!(echo_stream(send, recv).await); });
                    }
                    // Other accepted variants are ignored.
                    Some(_) => {}
                    None => break,
                }
            }
        }
    }

    tracing::info!("Finished handling session");

    Ok(())
}

// Exercises a server-initiated bidirectional stream. It sends a greeting, closes the
// write half, then reads the client's reply until EOF and logs it.
async fn open_bidi_test<S>(mut stream: S) -> anyhow::Result<()>
where
    S: Unpin + AsyncRead + AsyncWrite,
{
    const GREETING: &[u8] = b"Hello from a server initiated bidi stream";

    tracing::info!("Opening bidirectional stream");

    stream.write_all(GREETING).await.context("Failed to respond")?;
    // Shut down our side so the peer sees EOF before we wait for its answer.
    stream.shutdown().await?;

    let mut resp = Vec::new();
    stream.read_to_end(&mut resp).await?;
    tracing::info!("Got response from client: {resp:?}");

    Ok(())
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().init();

    // `/counter` is the WebTransport endpoint. Every other path goes to `StaticDir`
    // over `webtransport/static` and `./static`, with `client.html` as the default file.
    let static_files =
        StaticDir::new(["webtransport/static", "./static"]).defaults("client.html");
    let router = Router::new()
        .push(Router::with_path("counter").goal(connect))
        .push(Router::with_path("{*path}").get(static_files));

    // One certificate/key pair, embedded at compile time, secures both listeners.
    let cert = include_bytes!("../certs/cert.pem").to_vec();
    let key = include_bytes!("../certs/key.pem").to_vec();
    let tls = RustlsConfig::new(Keycert::new().cert(cert.as_slice()).key(key.as_slice()));

    // QUIC (UDP) and TLS-over-TCP share port 5800.
    let tcp = TcpListener::new(("0.0.0.0", 5800)).rustls(tls.clone());
    let acceptor = QuinnListener::new(tls, ("0.0.0.0", 5800))
        .join(tcp)
        .bind()
        .await;

    Server::new(acceptor).serve(router).await;
}

---catch-error---

---catch-error/Cargo.mdx---

catch-error/Cargo.toml
[package]
name = "example-catch-error"
version = "0.1.0"
edition = "2024"


[dependencies]
anyhow = "1"
eyre = "0.6"
salvo = { version = "0.77.1", features = ["anyhow", "eyre"] }
tokio = { version = "1", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = "0.3"

---catch-error/src---

---catch-error/src/main.mdx---

catch-error/src/main.rs
use salvo::prelude::*;

// Demonstration error type that renders its own response instead of relying on a
// framework-provided error writer.
struct CustomError;

// Writing a `CustomError` yields 500 Internal Server Error with the body "custom error".
#[async_trait]
impl Writer for CustomError {
    async fn write(self, _req: &mut Request, _depot: &mut Depot, res: &mut Response) {
        let status = StatusCode::INTERNAL_SERVER_ERROR;
        res.status_code(status);
        res.render("custom error");
    }
}

// Handler that returns an anyhow error for testing error handling
#[handler]
async fn handle_anyhow() -> Result<(), anyhow::Error> {
    Err(anyhow::anyhow!("handled anyhow error"))
}

// Handler that returns an eyre error for testing error handling
#[handler]
async fn handle_eyre() -> eyre::Result<()> {
    Err(eyre::Report::msg("handled eyre error"))
}

// Handler that returns our custom error type
#[handler]
async fn handle_custom() -> Result<(), CustomError> {
    Err(CustomError)
}

#[tokio::main]
async fn main() {
    // Initialize logging system
    tracing_subscriber::fmt().init();

    // One GET route per error flavour.
    let anyhow_route = Router::with_path("anyhow").get(handle_anyhow);
    let eyre_route = Router::with_path("eyre").get(handle_eyre);
    let custom_route = Router::with_path("custom").get(handle_custom);

    let router = Router::new()
        .push(anyhow_route)
        .push(eyre_route)
        .push(custom_route);

    // Serve on every interface, port 5800.
    let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
    Server::new(acceptor).serve(router).await;
}

---join-listeners---

---join-listeners/Cargo.mdx---

join-listeners/Cargo.toml
[package]
name = "example-join-listeners"
version = "0.1.0"
edition = "2024"


[dependencies]
salvo = { version = "0.77.1" }
tokio = { version = "1", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = "0.3"

---join-listeners/src---

---join-listeners/src/main.mdx---

join-listeners/src/main.rs
use salvo::prelude::*;

// Responds with a fixed greeting.
#[handler]
async fn hello() -> &'static str {
    const GREETING: &str = "Hello World";
    GREETING
}

#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().init();

    // Two TCP listeners joined into one acceptor: the same router answers on both ports.
    let primary = TcpListener::new("0.0.0.0:5800");
    let secondary = TcpListener::new("0.0.0.0:5801");
    let acceptor = primary.join(secondary).bind().await;

    let router = Router::new().get(hello);
    Server::new(acceptor).serve(router).await;
}

---oapi-todos---

---oapi-todos/Cargo.mdx---

oapi-todos/Cargo.toml
[package]
name = "example-oapi-todos"
version = "0.1.0"
edition = "2024"


[dependencies]
salvo = { version = "0.77.1", features = ["oapi"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tokio = { version = "1", features = ["macros"] }
compact_str = { version = "0.7", features = ["serde"] }
tracing = "0.1"
tracing-subscriber = "0.3"

---oapi-todos/src---

---oapi-todos/src/main.mdx---

oapi-todos/src/main.rs
use std::sync::LazyLock;

use salvo::oapi::{ToSchema, extract::*};
use salvo::prelude::*;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;

/// In-memory todo storage: the list of todos behind a `tokio::sync::Mutex`.
pub type Db = Mutex<Vec<Todo>>;

/// Process-wide store, created on first access by [`new_store`].
static STORE: LazyLock<Db> = LazyLock::new(new_store);

/// Builds an empty store.
pub fn new_store() -> Db {
    let todos: Vec<Todo> = Vec::new();
    Mutex::new(todos)
}

// A single todo item. It is kept in `STORE` and exchanged as JSON by the
// `/api/todos` endpoints. `ToSchema` exposes it in the generated OpenAPI document.
// (Plain `//` comments are used because `///` would change the schema description.)
#[derive(Serialize, Deserialize, Clone, Debug, ToSchema)]
pub struct Todo {
    // Client-supplied identifier that the handlers look todos up by.
    #[salvo(schema(example = 1))]
    pub id: u64,
    #[salvo(schema(example = "Buy coffee"))]
    pub text: String,
    pub completed: bool,
}

#[tokio::main]
async fn main() {
    // URL at which the generated OpenAPI JSON is served and read by every doc UI.
    const SPEC_URL: &str = "/api-doc/openapi.json";

    tracing_subscriber::fmt().init();

    // /api/todos and /api/todos/{id}
    let todo_item = Router::with_path("{id}")
        .patch(update_todo)
        .delete(delete_todo);
    let todos = Router::with_path("todos")
        .get(list_todos)
        .post(create_todo)
        .push(todo_item);
    let router = Router::new()
        .get(index)
        .push(Router::with_path("api").push(todos));

    // Build the spec from the API routes before the documentation routes are added.
    let doc = OpenApi::new("todos api", "0.0.1").merge_router(&router);

    let router = router
        .unshift(doc.into_router(SPEC_URL))
        .unshift(
            SwaggerUi::new(SPEC_URL)
                .title("Todos - SwaggerUI")
                .into_router("/swagger-ui"),
        )
        .unshift(
            Scalar::new(SPEC_URL)
                .title("Todos - Scalar")
                .into_router("/scalar"),
        )
        .unshift(
            RapiDoc::new(SPEC_URL)
                .title("Todos - RapiDoc")
                .into_router("/rapidoc"),
        )
        .unshift(
            ReDoc::new(SPEC_URL)
                .title("Todos - ReDoc")
                .into_router("/redoc"),
        );

    let acceptor = TcpListener::new("0.0.0.0:5800").bind().await;
    Server::new(acceptor).serve(router).await;
}

// Serves the static landing page that links to the documentation UIs.
#[handler]
pub async fn index() -> Text<&'static str> {
    let page: &'static str = INDEX_HTML;
    Text::Html(page)
}

/// List todos.
#[endpoint(
    tags("todos"),
    parameters(
        ("offset", description = "Offset is an optional query parameter."),
        ("limit", description = "Limit is an optional query parameter."),
    )
)]
pub async fn list_todos(
    offset: QueryParam<usize, false>,
    limit: QueryParam<usize, false>,
) -> Json<Vec<Todo>> {
    // Pagination: skip `offset` items (default 0), then return at most `limit` items
    // (default: all). Only the returned items are cloned, not the whole store.
    let todos = STORE.lock().await;
    let page: Vec<Todo> = todos
        .iter()
        .skip(offset.into_inner().unwrap_or(0))
        .take(limit.into_inner().unwrap_or(usize::MAX))
        .cloned()
        .collect();
    Json(page)
}

/// Create new todo.
#[endpoint(tags("todos"), status_codes(201, 409))]
pub async fn create_todo(req: JsonBody<Todo>) -> Result<StatusCode, StatusError> {
    tracing::debug!(todo = ?req, "create todo");

    let mut vec = STORE.lock().await;

    // Ids must be unique. A duplicate is answered with 409 Conflict, which is the
    // status documented in `status_codes(201, 409)` above.
    if vec.iter().any(|todo| todo.id == req.id) {
        tracing::debug!(id = ?req.id, "todo already exists");
        return Err(StatusError::conflict().brief("todo already exists"));
    }

    vec.push(req.into_inner());
    Ok(StatusCode::CREATED)
}

/// Update existing todo.
#[endpoint(tags("todos"), status_codes(200, 404))]
pub async fn update_todo(
    id: PathParam<u64>,
    updated: JsonBody<Todo>,
) -> Result<StatusCode, StatusError> {
    tracing::debug!(todo = ?updated, id = ?id, "update todo");
    let mut vec = STORE.lock().await;

    // Replace the first todo whose id matches the path id with the request body.
    // NOTE(review): the body's own `id` is stored as-is even if it differs from the
    // path id. Consider rejecting or overriding a mismatch to keep ids unique.
    match vec.iter_mut().find(|todo| todo.id == *id) {
        Some(todo) => {
            *todo = updated.into_inner();
            Ok(StatusCode::OK)
        }
        None => {
            tracing::debug!(?id, "todo is not found");
            Err(StatusError::not_found())
        }
    }
}

/// Delete todo.
#[endpoint(tags("todos"), status_codes(204, 404))]
pub async fn delete_todo(id: PathParam<u64>) -> Result<StatusCode, StatusError> {
    tracing::debug!(?id, "delete todo");

    let mut vec = STORE.lock().await;

    // Remove every todo with this id; a length change tells whether any existed.
    let len_before = vec.len();
    vec.retain(|todo| todo.id != *id);

    if vec.len() != len_before {
        // Success carries no body: 204, the status listed in `status_codes` above.
        Ok(StatusCode::NO_CONTENT)
    } else {
        tracing::debug!(?id, "todo is not found");
        Err(StatusError::not_found())
    }
}

// Landing page served by `index`. It links to the four documentation UIs that `main`
// mounts at `/swagger-ui`, `/scalar`, `/rapidoc` and `/redoc`. The links are relative,
// so they resolve against the page URL.
static INDEX_HTML: &str = r#"<!DOCTYPE html>
<html>
    <head>
        <title>Oapi todos</title>
    </head>
    <body>
        <ul>
        <li><a href="swagger-ui" target="_blank">swagger-ui</a></li>
        <li><a href="scalar" target="_blank">scalar</a></li>
        <li><a href="rapidoc" target="_blank">rapidoc</a></li>
        <li><a href="redoc" target="_blank">redoc</a></li>
        </ul>
    </body>
</html>
"#;

---db-mongodb---

---db-mongodb/Cargo.mdx---

db-mongodb/Cargo.toml
[package]
name = "example-db-mongodb"
version = "0.1.0"
edition = "2024"

[dependencies]
salvo = { version = "0.77.1" }
tokio = { version = "1", features = ["macros"] }
tracing = "0.1"
tracing-subscriber = "0.3"
# Versioned explicitly like the other dependencies. `futures.workspace = true` only
# resolves inside a Cargo workspace that declares `futures`, and no such workspace
# is visible for this manifest.
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
mongodb = "2"
thiserror = "1"

---db-mongodb/src---

---db-mongodb/src/main.mdx---

db-mongodb/src/main.rs
use std::sync::OnceLock;

use futures::stream::TryStreamExt;
use mongodb::{
    Client, Collection, IndexModel, bson::Document, bson::doc, bson::oid::ObjectId,
    options::IndexOptions,
};
use salvo::prelude::*;
use serde::{Deserialize, Serialize};

// Database and collection that the user handlers in this file read and write.
const DB_NAME: &str = "myApp";
const COLL_NAME: &str = "users";

use thiserror::Error;

// Application error type for the MongoDB handlers.
// `#[from]` lets `?` convert a `mongodb::error::Error` into `Error::ErrorMongo`.
// Its `Display` text is the fixed string "MongoDB Error" and does not include the
// underlying driver error, which is still available through `Debug`.
#[derive(Error, Debug)]
pub enum Error {
    #[error("MongoDB Error")]
    ErrorMongo(#[from] mongodb::error::Error),
}

// Result alias for handlers that propagate MongoDB failures with `?`.
pub type AppResult<T> = Result<T, Error>;

// Implement error response writer for our custom error type
#[async_trait]
impl Writer for Error {
    async fn write(self, _req: &mut Request, _depot: &mut Depot, _res: &mut Response) {}
}

// User record as stored in the `users` collection and exchanged as JSON.
// `_id` is an `Option`, so client JSON sent to `add_user` may omit it (serde's derive
// treats a missing `Option` field as `None`).
#[derive(Debug, Deserialize, Serialize)]
struct User {
    _id: Option<ObjectId>,
    first_name: String,
    last_name: String,
    // Looked up by `get_user`.
    username: String,
    email: String,
}

// Global MongoDB client. It must be set once before requests are served.
// NOTE(review): the initialisation is not visible in this part of the file; confirm
// that `MONGODB_CLIENT.set(...)` runs in `main` before the server starts.
static MONGODB_CLIENT: OnceLock<Client> = OnceLock::new();

// Returns the global MongoDB client.
//
// Panics if `MONGODB_CLIENT` has not been initialised. That is a startup bug, not a
// recoverable runtime condition, so the panic message states the invariant.
#[inline]
pub fn get_mongodb_client() -> &'static Client {
    MONGODB_CLIENT
        .get()
        .expect("MONGODB_CLIENT must be initialised before handling requests")
}

// Handler for adding a new user to the database
#[handler]
async fn add_user(req: &mut Request, res: &mut Response) {
    let client = get_mongodb_client();
    let coll_users = client.database(DB_NAME).collection::<Document>(COLL_NAME);
    let new_user = req.parse_json::<User>().await.unwrap();

    // Create BSON document from user data
    let user = doc! {
        "first_name": new_user.first_name,
        "last_name": new_user.last_name,
        "username": new_user.username,
        "email": new_user.email,
    };

    // Insert user document into MongoDB
    let result = coll_users.insert_one(user, None).await;
    match result {
        Ok(id) => res.render(format!("user added with ID {:?}", id.inserted_id)),
        Err(e) => res.render(format!("error {e:?}")),
    }
}

// Lists every user in the collection as a JSON array. MongoDB failures propagate via
// `?` as `Error`, whose `Writer` impl produces the response.
#[handler]
async fn get_users(res: &mut Response) -> AppResult<()> {
    let coll_users = get_mongodb_client()
        .database(DB_NAME)
        .collection::<User>(COLL_NAME);

    // No filter: the cursor yields every document, drained into a Vec.
    let users: Vec<User> = coll_users.find(None, None).await?.try_collect().await?;

    res.render(Json(users));
    Ok(())
}

// Handler for retrieving a single user by username
#[handler]
async fn get_user(req: &mut Request, res: &mut Response) {
    let client = get_mongodb_client();
    let coll_users: Collection<User> = client.database(DB_NAME).collection(COLL_NAME);

    let username = req.param::<String>("username").unwrap();
    // Find user by username
    match coll_users
        .find_one(doc! { "username": &username }, None)
        .await
    {
        Ok(Some(user)) => res.render(Json(user)),
        Ok(None) => res.render(format!("No user found with username {username}")),
        Err(e) => res.render(format!("error {e:?}")),
    }
}

// Create a unique index on the username field
async fn create_username_index(client: &Client) {
    let options = IndexOptions::builder().unique(true).build();
    let model = IndexModel::builder()
        .keys(doc! { "username": 1 })
        .options