Files
authorize/src/main.rs
matresnan 1d107a97ff fix: 修正创建令牌成功返回码为200
- 将创建令牌成功时的返回状态码从404改为200
2025-08-25 19:55:25 +08:00

364 lines
7.4 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use axum::{
extract::{Query, State},
http::StatusCode,
routing::get,
Json, Router,
};
use chrono::{prelude::*, Duration, ParseError};
use serde::{Deserialize, Serialize};
use std::{env, ops::Deref, sync::Arc};
mod db;
mod generate;
/// Opens the token database, registers the HTTP routes and serves them on `0.0.0.0:3009`.
///
/// # Errors
///
/// Returns an error if the working directory cannot be read, the resulting database
/// path is not valid UTF-8, `db::Db::new` fails, or the TCP listener cannot bind or serve.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The database lives at `<working dir>/authorize_data/database.db` on every OS; it is
    // resolved from the process working directory, not from the executable's location.
    let db_path: String = env::current_dir()?
        .join("authorize_data")
        .join("database.db")
        .into_os_string()
        .into_string()
        // Reject non-UTF-8 paths instead of lossily rewriting them into a different path.
        .map_err(|path| format!("database path is not valid UTF-8: {path:?}"))?;
    let db = Arc::new(db::Db::new(db_path).await?);
    let generator = Arc::new(
        generate::TokenGenerator::new()
            .with_uppercase(true)
            .with_lowercase(true)
            .with_numbers(true),
    );
    let state = AppState { db, generator };
    // NOTE(review): /create, /reset and /renewal modify stored tokens yet are GET routes,
    // and no handler in this file checks who is calling while the listener binds to
    // 0.0.0.0; confirm the service is only reachable from trusted clients.
    let app = Router::new()
        .route("/create", get(create_token))
        .route("/verify", get(verify_token))
        .route("/info", get(get_token_info))
        .route("/reset", get(update_token_status))
        .route("/renewal", get(renewal_token))
        .route("/project", get(get_projects))
        .route("/all", get(get_project_token))
        .with_state(state);
    let listener = tokio::net::TcpListener::bind("0.0.0.0:3009").await?;
    axum::serve(listener, app).await?;
    Ok(())
}
/// Handles `GET /create?project=...&device_id=...`.
///
/// If `db.exists_project` already finds a record for this project/device pair, that
/// record's token is returned (HTTP 200, `code: 200`) and nothing is inserted.
/// Otherwise a 16-character token is generated, stored with an expiry 7 days after the
/// current local time, and returned with HTTP 200 / `code: 200`.
///
/// If either database call fails, the reply is HTTP 500 / `code: 500` with an empty
/// `token`. The request task does not panic, and a failed insert is never reported as
/// a success.
async fn create_token(
    State(state): State<AppState>,
    Query(args): Query<CreateToken>,
) -> (StatusCode, Json<CreateTokenInfo>) {
    // NOTE(review): the lookup and the insert below are separate calls, so two concurrent
    // requests for the same pair can both insert unless the database enforces
    // uniqueness — confirm in `db`.
    match state.db.exists_project(&args.project, &args.device_id).await {
        Ok(Some(info)) => {
            return (
                StatusCode::OK,
                Json(CreateTokenInfo {
                    code: 200,
                    project: info.project,
                    device_id: info.device_id,
                    token: info.token,
                    msg: "token已存在请勿重复创建".to_owned(),
                }),
            );
        }
        Ok(None) => {}
        Err(err) => {
            eprintln!("create_token: exists_project failed: {err:?}");
            return create_token_failure(&args, "服务器内部错误");
        }
    }

    let insert_time = get_current_datetime();
    // `get_current_datetime` emits exactly the format `add_day` parses.
    let expire = add_day(&insert_time, 7).expect("get_current_datetime output must be parseable");
    let token: String = state.generator.generate(16);
    let inserted = state
        .db
        .insert_authorize(InsertArgs {
            project: args.project.clone(),
            token: token.clone(),
            device_id: args.device_id.clone(),
            // Same value `update_token_status` writes for `enable=true`; presumably this
            // means "enabled" despite the field name — confirm against `db`.
            disable: 1,
            expire,
            insert_time,
        })
        .await;
    if let Err(err) = inserted {
        eprintln!("create_token: insert_authorize failed: {err:?}");
        return create_token_failure(&args, "token创建失败");
    }

    (
        StatusCode::OK,
        Json(CreateTokenInfo {
            code: 200,
            project: args.project,
            device_id: args.device_id,
            token,
            msg: "token创建成功".to_owned(),
        }),
    )
}

/// Builds the HTTP 500 reply `create_token` sends when a database call fails.
fn create_token_failure(args: &CreateToken, msg: &str) -> (StatusCode, Json<CreateTokenInfo>) {
    (
        StatusCode::INTERNAL_SERVER_ERROR,
        Json(CreateTokenInfo {
            code: 500,
            project: args.project.clone(),
            device_id: args.device_id.clone(),
            token: String::new(),
            msg: msg.to_owned(),
        }),
    )
}
/// Handles `GET /verify?token=...`.
///
/// When the database call succeeds, the reply is always HTTP 200:
/// - `code: 200` / "正常" if `db.verify_token` returns `true`;
/// - `code: 404` / "token已过期" otherwise.
///
/// A database error is answered with HTTP 500 / `code: 500` instead of panicking the
/// request task.
async fn verify_token(
    State(state): State<AppState>,
    Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<VerifyResult>) {
    let (status, code, msg) = match state.db.verify_token(&args.token).await {
        Ok(true) => (StatusCode::OK, 200, "正常"),
        // NOTE(review): `false` may also cover unknown or disabled tokens (depends on
        // `db::verify_token`), but the message always says "expired".
        Ok(false) => (StatusCode::OK, 404, "token已过期"),
        Err(err) => {
            eprintln!("verify_token: database error: {err:?}");
            (StatusCode::INTERNAL_SERVER_ERROR, 500, "服务器内部错误")
        }
    };
    (
        status,
        Json(VerifyResult {
            code,
            msg: msg.to_owned(),
        }),
    )
}
/// Handles `GET /info?token=...`: returns the stored record for a token.
///
/// - Found: HTTP 200 with a `TokenInfo` body (`code: 200`).
/// - Not found: HTTP 200 with a `QueryError` body (`code: 404`).
/// - Database error: HTTP 500 with a `QueryError` body (`code: 500`). The request
///   task does not panic.
async fn get_token_info(
    State(state): State<AppState>,
    Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<TokenResponse>) {
    match state.db.get_token_info(&args.token).await {
        Ok(Some(auth)) => {
            let v = auth.into_inner();
            (
                StatusCode::OK,
                Json(TokenResponse::Success(TokenInfo {
                    code: 200,
                    project: v.project,
                    token: v.token,
                    device_id: v.device_id,
                    disable: v.disable,
                    expire: v.expire,
                    insert_time: v.insert_time,
                })),
            )
        }
        Ok(None) => (
            StatusCode::OK,
            Json(TokenResponse::Error(QueryError {
                code: 404,
                msg: "未查询token相关信息".to_owned(),
            })),
        ),
        Err(err) => {
            eprintln!("get_token_info: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(TokenResponse::Error(QueryError {
                    code: 500,
                    msg: "服务器内部错误".to_owned(),
                })),
            )
        }
    }
}
/// Handles `GET /reset?token=...&enable=...`: stores 1 (`enable=true`) or 0
/// (`enable=false`) as the token's state via `db.update_token_state`.
///
/// - The database reports success: HTTP 200, `code: 200` / "操作成功".
/// - The database reports no change: HTTP 200, `code: 404` / "操作失败".
/// - A database error: HTTP 500, `code: 500`. The request task does not panic.
async fn update_token_status(
    State(state): State<AppState>,
    Query(args): Query<UpdateTokenStatus>,
) -> (StatusCode, Json<VerifyResult>) {
    let new_state = if args.enable { 1 } else { 0 };
    let (status, code, msg) = match state.db.update_token_state(&args.token, new_state).await {
        Ok(true) => (StatusCode::OK, 200, "操作成功"),
        Ok(false) => (StatusCode::OK, 404, "操作失败"),
        Err(err) => {
            eprintln!("update_token_status: database error: {err:?}");
            (StatusCode::INTERNAL_SERVER_ERROR, 500, "服务器内部错误")
        }
    };
    (
        status,
        Json(VerifyResult {
            code,
            msg: msg.to_owned(),
        }),
    )
}
/// Handles `GET /renewal?token=...&days=...`: asks `db.update_token_expiry` to move the
/// token's expiry and returns the updated record.
///
/// - Updated: HTTP 200 with a `TokenInfo` body (`code: 200`).
/// - The database returns no record: HTTP 200 with a `QueryError` body (`code: 404`).
/// - Database error: HTTP 500 with a `QueryError` body (`code: 500`). The request task
///   does not panic.
async fn renewal_token(
    State(state): State<AppState>,
    Query(args): Query<RenewalToken>,
) -> (StatusCode, Json<TokenResponse>) {
    // NOTE(review): `days` is forwarded unvalidated (negative or very large values
    // included); confirm `db::update_token_expiry` handles that range.
    match state.db.update_token_expiry(&args.token, args.days).await {
        Ok(Some(auth)) => {
            let v = auth.into_inner();
            (
                StatusCode::OK,
                Json(TokenResponse::Success(TokenInfo {
                    code: 200,
                    project: v.project,
                    token: v.token,
                    device_id: v.device_id,
                    disable: v.disable,
                    expire: v.expire,
                    insert_time: v.insert_time,
                })),
            )
        }
        Ok(None) => (
            StatusCode::OK,
            Json(TokenResponse::Error(QueryError {
                code: 404,
                msg: "操作失败".to_owned(),
            })),
        ),
        Err(err) => {
            eprintln!("renewal_token: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(TokenResponse::Error(QueryError {
                    code: 500,
                    msg: "服务器内部错误".to_owned(),
                })),
            )
        }
    }
}
/// Handles `GET /project`: returns every project name reported by `db.get_all_project`.
///
/// A database error is answered with HTTP 500, `code: 500` and an empty list instead of
/// panicking the request task.
async fn get_projects(State(state): State<AppState>) -> (StatusCode, Json<ProjectResponse>) {
    match state.db.get_all_project().await {
        Ok(projects) => (
            StatusCode::OK,
            Json(ProjectResponse {
                code: 200,
                projects,
            }),
        ),
        Err(err) => {
            eprintln!("get_projects: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ProjectResponse {
                    code: 500,
                    projects: Vec::new(),
                }),
            )
        }
    }
}
/// Handles `GET /all?project=...`: returns the records `db.query_project_token` yields
/// for the given project.
///
/// A database error is answered with HTTP 500, `code: 500` and an empty `data` list
/// instead of panicking the request task.
async fn get_project_token(
    State(state): State<AppState>,
    Query(args): Query<QueryProject>,
) -> (StatusCode, Json<ProjectToken>) {
    match state.db.query_project_token(&args.project).await {
        Ok(items) => (
            StatusCode::OK,
            Json(ProjectToken {
                code: 200,
                data: items,
            }),
        ),
        Err(err) => {
            eprintln!("get_project_token: database error: {err:?}");
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ProjectToken {
                    code: 500,
                    data: Vec::new(),
                }),
            )
        }
    }
}
fn get_current_datetime() -> String {
Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
}
/// Parses `t` as `YYYY-MM-DD HH:MM:SS`, shifts it by `days` whole days (negative moves
/// backwards), and formats the result the same way.
///
/// # Errors
///
/// Returns the `ParseError` from chrono when `t` does not match the format.
fn add_day(t: &str, days: i64) -> Result<String, ParseError> {
    const FMT: &str = "%Y-%m-%d %H:%M:%S";
    let shifted = NaiveDateTime::parse_from_str(t, FMT)? + Duration::days(days);
    Ok(shifted.format(FMT).to_string())
}
/// Values for a new authorization record, handed to `db::Db::insert_authorize`.
pub struct InsertArgs {
    /// Project the token belongs to.
    pub project: String,
    /// Device identifier stored alongside the token.
    pub device_id: String,
    /// Generated token string.
    pub token: String,
    /// State flag; `create_token` passes 1, the same value `/reset?enable=true` writes.
    pub disable: i8,
    /// Creation timestamp, `YYYY-MM-DD HH:MM:SS`.
    pub insert_time: String,
    /// Expiry timestamp, `YYYY-MM-DD HH:MM:SS`.
    pub expire: String,
}
/// Shared handler state; cloning only bumps the two `Arc` reference counts.
#[derive(Clone)]
struct AppState {
    /// Generates new token strings (used by `/create`).
    generator: Arc<generate::TokenGenerator>,
    /// Database handle used by every route.
    db: Arc<db::Db>,
}
/// Query parameters for `GET /create`.
#[derive(Deserialize)]
struct CreateToken {
    /// Project the token is issued for.
    project: String,
    /// Device identifier stored with the token; together with `project` it is used to
    /// look up an existing token.
    device_id: String,
}
/// Token value taken from a query string.
///
/// It dereferences to `str`, so `&token` can be passed wherever a `&str` is expected.
#[derive(Deserialize, Debug, Clone)]
struct Token(String);

impl Deref for Token {
    type Target = str;

    fn deref(&self) -> &str {
        self.0.as_str()
    }
}
/// Query parameters carrying a single token (`/verify` and `/info`).
#[derive(Deserialize)]
struct VerifyToken {
    token: Token,
}
/// Query parameters for `GET /renewal`.
#[derive(Deserialize)]
struct RenewalToken {
    token: Token,
    /// Forwarded unchanged to `db.update_token_expiry`; presumably the number of days to
    /// extend by — confirm sign/range handling in `db`.
    days: i64,
}
/// Query parameters for `GET /all`.
#[derive(Deserialize)]
struct QueryProject {
    /// Project whose tokens are listed.
    project: String,
}
/// Body returned by `/verify` and `/reset`; `code` mirrors the outcome, while the HTTP
/// status stays 200 for non-error outcomes.
// Field order below is the JSON key order.
#[derive(Serialize)]
struct VerifyResult {
    code: i16,
    msg: String,
}
/// Successful body for `/info` and `/renewal`: the stored token record plus `code`.
#[derive(Serialize)]
struct TokenInfo {
    code: i16,
    project: String,
    token: String,
    device_id: String,
    /// State flag; `/reset` writes 1 for `enable=true` and 0 otherwise.
    disable: i8,
    /// Expiry timestamp string as stored.
    expire: String,
    /// Creation timestamp string as stored.
    insert_time: String,
}
/// Error body for `/info` and `/renewal`.
#[derive(Serialize)]
struct QueryError {
    code: i16,
    msg: String,
}
/// Response for `/info` and `/renewal`. Because it is `untagged`, the JSON is exactly the
/// inner struct's fields, with no variant name wrapped around them.
#[derive(Serialize)]
#[serde(untagged)]
enum TokenResponse {
    Success(TokenInfo),
    Error(QueryError),
}
/// Body returned by `/create`, for both newly created and already existing tokens.
#[derive(Serialize)]
struct CreateTokenInfo {
    code: i16,
    project: String,
    device_id: String,
    token: String,
    msg: String,
}
/// Query parameters for `GET /reset`.
// NOTE(review): `token` is a plain `String` here, unlike the `Token` newtype used by the
// other query structs.
#[derive(Deserialize)]
struct UpdateTokenStatus {
    token: String,
    /// `true` stores state 1, `false` stores state 0.
    enable: bool,
}
/// Body returned by `/project`: all known project names.
// Field order below is the JSON key order.
#[derive(Serialize)]
struct ProjectResponse {
    code: i32,
    projects: Vec<String>,
}
/// Body returned by `/all`: the records returned for one project.
#[derive(Serialize)]
struct ProjectToken {
    code: i32,
    data: Vec<db::Authorize>,
}