Files
authorize/src/main.rs
matresnan 806480adb1 docs: 更新README文档并改进错误处理
- 在README.md文件中添加获取项目列表的API接口文档说明。
- 将服务启动代码中的错误处理从 `unwrap()` 改为更安全的 `?` 操作符以传播错误。
2025-08-21 11:28:50 +08:00

337 lines
6.9 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
extract::{Query, State},
http::StatusCode,
routing::get,
Json, Router,
};
use chrono::{prelude::*, Duration, ParseError};
use serde::{Deserialize, Serialize};
use std::{env, ops::Deref, sync::Arc};
mod db;
mod generate;
/// Starts the authorization HTTP service.
///
/// Builds the database path `authorize_data/database.db` under the current
/// working directory, creates the shared [`AppState`], registers the routes
/// and serves them on `0.0.0.0:3009`.
///
/// # Errors
///
/// Returns an error when the current directory cannot be determined, when
/// `db::Db::new` fails, when the listener cannot bind, or when the server
/// itself returns an error.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The database file is resolved relative to the directory the process was
    // started from. A failure here is reported through the `Result` instead of
    // panicking, consistent with the other fallible steps below.
    let db_path: String = env::current_dir()
        .map_err(|err| format!("无法获取当前目录: {err}"))?
        .join("authorize_data/database.db")
        .to_string_lossy()
        .into_owned();

    let db = Arc::new(db::Db::new(db_path).await?);
    let generator = Arc::new(
        generate::TokenGenerator::new()
            .with_uppercase(true)
            .with_lowercase(true)
            .with_numbers(true),
    );
    let state = AppState { db, generator };

    let app = Router::new()
        .route("/create", get(create_token))
        .route("/verify", get(verify_token))
        .route("/info", get(get_token_info))
        .route("/reset", get(update_token_status))
        .route("/renewal", get(renewal_token))
        .route("/project", get(get_projects))
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3009").await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Handles `GET /create`: issues a token for a project/device pair.
///
/// When `exists_project` already returns a record for the pair, that record is
/// echoed back with body code `404` and nothing new is created. Otherwise a
/// token is produced with `generator.generate(16)` and stored with an expiry
/// seven days after the current local time.
///
/// Failures (database lookup, expiry computation, insert) are answered with
/// HTTP 500 and body code `500` instead of panicking, and success is reported
/// only after the insert itself succeeded.
async fn create_token(
    State(state): State<AppState>,
    Query(args): Query<CreateToken>,
) -> (StatusCode, Json<CreateTokenInfo>) {
    // Builds the uniform failure body; the token field is left empty because
    // no usable token exists on these paths.
    let failure = |msg: &str| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(CreateTokenInfo {
                code: 500,
                project: args.project.clone(),
                device_id: args.device_id.clone(),
                token: String::new(),
                msg: msg.to_owned(),
            }),
        )
    };

    let lookup = state
        .db
        .exists_project(&args.project, &args.device_id)
        .await;
    let existing = match lookup {
        Ok(found) => found,
        Err(err) => {
            eprintln!("create_token: exists_project failed: {err:?}");
            return failure("数据库查询失败");
        }
    };

    // A record already exists for this project/device: return it unchanged.
    if let Some(info) = existing {
        return (
            StatusCode::OK,
            Json(CreateTokenInfo {
                code: 404,
                project: info.project,
                device_id: info.device_id,
                token: info.token,
                msg: "token已存在请勿重复创建".to_owned(),
            }),
        );
    }

    let str_time = get_current_datetime();
    let exp_time = match add_day(&str_time, 7) {
        Ok(expiry) => expiry,
        Err(err) => {
            eprintln!("create_token: cannot compute expiry from {str_time:?}: {err:?}");
            return failure("过期时间计算失败");
        }
    };

    let token: String = state.generator.generate(16);
    let inserted = state
        .db
        .insert_authorize(InsertArgs {
            project: args.project.clone(),
            token: token.clone(),
            device_id: args.device_id.clone(),
            disable: 1,
            expire: exp_time,
            insert_time: str_time,
        })
        .await;
    // Do not hand out a token that was never persisted.
    if inserted.is_err() {
        eprintln!("create_token: insert_authorize failed");
        return failure("token创建失败");
    }

    (
        StatusCode::OK,
        Json(CreateTokenInfo {
            code: 200,
            project: args.project,
            device_id: args.device_id,
            token,
            msg: "token创建成功".to_owned(),
        }),
    )
}
/// Handles `GET /verify`: checks a token through `db.verify_token`.
///
/// Responds with body code `200` when the database reports `true` and body
/// code `404` when it reports `false`. A database error is answered with
/// HTTP 500 and body code `500` instead of panicking.
async fn verify_token(
    State(state): State<AppState>,
    Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<VerifyResult>) {
    let outcome = state.db.verify_token(&args.token).await;
    let (status, code, msg) = match outcome {
        Ok(true) => (StatusCode::OK, 200, "正常"),
        Ok(false) => (StatusCode::OK, 404, "token已过期"),
        Err(err) => {
            eprintln!("verify_token: database error: {err:?}");
            (StatusCode::INTERNAL_SERVER_ERROR, 500, "数据库查询失败")
        }
    };
    (
        status,
        Json(VerifyResult {
            code,
            msg: msg.to_owned(),
        }),
    )
}
/// Handles `GET /info`: returns the stored record for a token.
///
/// A found record is returned as [`TokenResponse::Success`] with body code
/// `200`; a missing record yields [`TokenResponse::Error`] with body code
/// `404`. A database error is answered with HTTP 500 and body code `500`
/// instead of panicking.
async fn get_token_info(
    State(state): State<AppState>,
    Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<TokenResponse>) {
    let lookup = state.db.get_token_info(&args.token).await;
    match lookup {
        Ok(Some(auth)) => {
            let v = auth.into_inner();
            (
                StatusCode::OK,
                Json(TokenResponse::Success(TokenInfo {
                    code: 200,
                    project: v.project,
                    token: v.token,
                    device_id: v.device_id,
                    disable: v.disable,
                    expire: v.expire,
                    insert_time: v.insert_time,
                })),
            )
        }
        Ok(None) => (
            StatusCode::OK,
            Json(TokenResponse::Error(QueryError {
                code: 404,
                msg: "未查询token相关信息".to_owned(),
            })),
        ),
        Err(err) => {
            eprintln!("get_token_info: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(TokenResponse::Error(QueryError {
                    code: 500,
                    msg: "数据库查询失败".to_owned(),
                })),
            )
        }
    }
}
/// Handles `GET /reset`: switches a token's state.
///
/// `enable=true` passes `1` to `db.update_token_state`, `enable=false` passes
/// `0`. The body code is `200` when the database reports `true` and `404`
/// when it reports `false`. A database error is answered with HTTP 500 and
/// body code `500` instead of panicking.
async fn update_token_status(
    State(state): State<AppState>,
    Query(args): Query<UpdateTokenStatus>,
) -> (StatusCode, Json<VerifyResult>) {
    // Only the state value differs between the two cases, so pick it first
    // and issue a single database call.
    let new_state = if args.enable { 1 } else { 0 };
    let outcome = state.db.update_token_state(&args.token, new_state).await;

    let (status, code, msg) = match outcome {
        Ok(true) => (StatusCode::OK, 200, "操作成功"),
        Ok(false) => (StatusCode::OK, 404, "操作失败"),
        Err(err) => {
            eprintln!("update_token_status: database error: {err:?}");
            (StatusCode::INTERNAL_SERVER_ERROR, 500, "数据库操作失败")
        }
    };
    (
        status,
        Json(VerifyResult {
            code,
            msg: msg.to_owned(),
        }),
    )
}
/// Handles `GET /renewal`: forwards `token` and `days` to
/// `db.update_token_expiry` and returns the resulting record.
///
/// A returned record is sent back as [`TokenResponse::Success`] with body code
/// `200`; `None` yields [`TokenResponse::Error`] with body code `404`. A
/// database error is answered with HTTP 500 and body code `500` instead of
/// panicking.
async fn renewal_token(
    State(state): State<AppState>,
    Query(args): Query<RenewalToken>,
) -> (StatusCode, Json<TokenResponse>) {
    let renewed = state
        .db
        .update_token_expiry(&args.token, args.days)
        .await;
    match renewed {
        Ok(Some(auth)) => {
            let v = auth.into_inner();
            (
                StatusCode::OK,
                Json(TokenResponse::Success(TokenInfo {
                    code: 200,
                    project: v.project,
                    token: v.token,
                    device_id: v.device_id,
                    disable: v.disable,
                    expire: v.expire,
                    insert_time: v.insert_time,
                })),
            )
        }
        Ok(None) => (
            StatusCode::OK,
            Json(TokenResponse::Error(QueryError {
                code: 404,
                msg: "操作失败".to_owned(),
            })),
        ),
        Err(err) => {
            eprintln!("renewal_token: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(TokenResponse::Error(QueryError {
                    code: 500,
                    msg: "数据库操作失败".to_owned(),
                })),
            )
        }
    }
}
/// Handles `GET /project`: returns every project name known to the database.
///
/// On success the body carries code `200` and the list from
/// `db.get_all_project`. A database error is answered with HTTP 500, body
/// code `500` and an empty list instead of panicking.
async fn get_projects(
    State(state): State<AppState>,
) -> (StatusCode, Json<ProjectResponse>) {
    let listing = state.db.get_all_project().await;
    match listing {
        Ok(projects) => (
            StatusCode::OK,
            Json(ProjectResponse {
                code: 200,
                projects,
            }),
        ),
        Err(err) => {
            eprintln!("get_projects: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ProjectResponse {
                    code: 500,
                    projects: Vec::new(),
                }),
            )
        }
    }
}
/// Returns the current local time rendered with the pattern
/// `%Y-%m-%d %H:%M:%S`.
fn get_current_datetime() -> String {
    let now = Local::now();
    let rendered = now.format("%Y-%m-%d %H:%M:%S");
    rendered.to_string()
}
/// Parses `t` as `%Y-%m-%d %H:%M:%S`, shifts it by `days` days and renders the
/// result with the same pattern.
///
/// # Errors
///
/// Returns the chrono `ParseError` when `t` does not match the pattern.
fn add_day(t: &str, days: i64) -> Result<String, ParseError> {
    const LAYOUT: &str = "%Y-%m-%d %H:%M:%S";

    NaiveDateTime::parse_from_str(t, LAYOUT)
        .map(|parsed| (parsed + Duration::days(days)).format(LAYOUT).to_string())
}
/// Values handed to `db.insert_authorize` when `create_token` stores a new
/// token record.
pub struct InsertArgs {
// Project name taken from the `/create` query.
pub project: String,
// The freshly generated token string.
pub token: String,
// Device identifier taken from the `/create` query.
pub device_id: String,
// `create_token` always stores `1` here.
// NOTE(review): the meaning of the individual flag values is not visible in this file — confirm against `db`.
pub disable: i8,
// Expiry timestamp, formatted as `%Y-%m-%d %H:%M:%S` by `add_day`.
pub expire: String,
// Creation timestamp, formatted as `%Y-%m-%d %H:%M:%S` by `get_current_datetime`.
pub insert_time: String,
}
/// Shared state attached to the router with `with_state` and received by the
/// handlers through the `State` extractor. Both members are behind `Arc`, so
/// cloning the state only clones the two pointers.
#[derive(Clone)]
struct AppState {
// Database handle used by every handler.
db: Arc<db::Db>,
// Token generator configured in `main`; used by `create_token`.
generator: Arc<generate::TokenGenerator>,
}
/// Query parameters accepted by the `/create` route.
#[derive(Deserialize)]
struct CreateToken {
// Project the token is issued for.
project: String,
// Device the token is issued for.
device_id: String,
}
/// Newtype wrapper around the token string received in query parameters.
#[derive(Deserialize, Debug, Clone)]
struct Token(String);
// Lets a `Token` be used wherever a `&str` is expected.
impl Deref for Token {
    type Target = str;

    /// Borrows the wrapped string as a `str` slice.
    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
/// Query parameters accepted by the `/verify` and `/info` routes.
#[derive(Deserialize)]
struct VerifyToken {
// Token to look up.
token: Token,
}
/// Query parameters accepted by the `/renewal` route.
#[derive(Deserialize)]
struct RenewalToken {
// Token whose expiry is updated.
token: Token,
// Day count forwarded unchanged to `db.update_token_expiry`.
// NOTE(review): negative values are not rejected in this file — confirm how `db` treats them.
days: i64,
}
/// JSON body returned by the `/verify` and `/reset` handlers.
#[derive(Serialize)]
struct VerifyResult {
// Application-level result code (the handlers use 200 and 404).
code: i16,
// Human-readable result message.
msg: String,
}
/// Success payload of the `/info` and `/renewal` handlers: the stored token
/// record plus a result code.
#[derive(Serialize)]
struct TokenInfo {
// Application-level result code; the handlers set 200.
code: i16,
// Project the token belongs to.
project: String,
// The token string.
token: String,
// Device the token belongs to.
device_id: String,
// Flag copied from the stored record.
// NOTE(review): value semantics are not visible in this file — confirm against `db`.
disable: i8,
// Expiry timestamp as stored.
expire: String,
// Creation timestamp as stored.
insert_time: String,
}
/// Error payload of the `/info` and `/renewal` handlers.
#[derive(Serialize)]
struct QueryError {
// Application-level result code; the handlers set 404.
code: i16,
// Human-readable error message.
msg: String,
}
/// Response body of the `/info` and `/renewal` handlers.
///
/// Because of `#[serde(untagged)]` the variant name is not written to the
/// output; the JSON is just the fields of the contained struct.
#[derive(Serialize)]
#[serde(untagged)]
enum TokenResponse {
// A record was found or updated.
Success(TokenInfo),
// No record was returned by the database call.
Error(QueryError),
}
/// JSON body returned by the `/create` handler.
#[derive(Serialize)]
struct CreateTokenInfo {
// Application-level result code: 200 for a new token, 404 when one already exists.
code: i16,
// Project the token belongs to.
project: String,
// Device the token belongs to.
device_id: String,
// The newly created token, or the already existing one.
token: String,
// Human-readable result message.
msg: String,
}
/// Query parameters accepted by the `/reset` route.
#[derive(Deserialize)]
struct UpdateTokenStatus {
// Token whose state is changed (a plain `String` here, unlike the `Token` newtype used by the other queries).
token: String,
// `true` makes the handler pass state 1 to `db.update_token_state`, `false` passes 0.
enable: bool,
}
/// JSON body returned by the `/project` handler.
#[derive(Serialize)]
struct ProjectResponse {
// Application-level result code; the handler sets 200. Declared as `i32`, whereas the other response types use `i16`.
code: i32,
// Project names returned by `db.get_all_project`.
projects: Vec<String>,
}