feat: 增加令牌禁用、启用、续期功能并优化错误处理

- 增加禁用Token和更新Token有效期功能并优化返回类型
- 增加令牌禁用、启用和续期功能,并优化令牌信息查询的错误处理
This commit is contained in:
2025-08-19 15:08:07 +08:00
parent 3714ec6cd0
commit d0ccda69e2
2 changed files with 161 additions and 18 deletions

View File

@@ -1,4 +1,4 @@
use chrono::{Local, NaiveDateTime}; use chrono::{Duration, Local, NaiveDateTime};
use serde::Serialize; use serde::Serialize;
use std::{fs, path::Path}; use std::{fs, path::Path};
use welds::{connections::sqlite::SqliteClient, prelude::*}; use welds::{connections::sqlite::SqliteClient, prelude::*};
@@ -118,12 +118,14 @@ impl Db {
pub async fn get_token_info( pub async fn get_token_info(
&self, &self,
token: &str, token: &str,
) -> Result<Option<Authorize>, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<
Option<DbState<Authorize>>,
Box<dyn std::error::Error + Send + Sync>,
> {
let row = Authorize::where_col(|p| p.token.equal(token)) let row = Authorize::where_col(|p| p.token.equal(token))
.limit(1) .limit(1)
.run(&self.client) .run(&self.client)
.await? .await?
.into_inners()
.into_iter() .into_iter()
.next(); .next();
@@ -146,4 +148,49 @@ impl Db {
let _created = auth.save(&self.client).await?; let _created = auth.save(&self.client).await?;
Ok(auth.id) Ok(auth.id)
} }
/// 禁用 Token
pub async fn update_token_state(
&self,
token: &str,
state: i8,
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
let rows = Authorize::all()
.where_col(|a| a.token.equal(token))
.run(&self.client)
.await?;
if rows.is_empty() {
return Ok(false);
}
for mut row in rows {
row.disable = state;
row.save(&self.client).await?;
}
Ok(true)
}
/// 更新 Token 的有效期
pub async fn update_token_expiry(
&self,
token: &str,
days_to_add: i64,
) -> Result<
Option<DbState<Authorize>>,
Box<dyn std::error::Error + Send + Sync>,
> {
let auth_info = self.get_token_info(token).await?;
match auth_info {
Some(mut auth) => {
let current_expiry =
NaiveDateTime::parse_from_str(&auth.expire, "%Y-%m-%d %H:%M:%S")?;
let new_expiry = current_expiry + Duration::days(days_to_add);
auth.expire = new_expiry.format("%Y-%m-%d %H:%M:%S").to_string();
auth.save(&self.client).await?;
Ok(Some(auth))
}
None => Ok(None),
}
}
} }

View File

@@ -46,6 +46,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.route("/create", get(create_token)) .route("/create", get(create_token))
.route("/verify", get(verify_token)) .route("/verify", get(verify_token))
.route("/info", get(get_token_info)) .route("/info", get(get_token_info))
.route("/disable", get(disable_token))
.route("/enable", get(enable_token))
.route("/renewal", get(renewal_token))
.with_state(state); .with_state(state);
let listener = tokio::net::TcpListener::bind("127.0.0.1:3009") let listener = tokio::net::TcpListener::bind("127.0.0.1:3009")
@@ -134,31 +137,118 @@ async fn get_token_info(
State(state): State<AppState>, State(state): State<AppState>,
Query(args): Query<VerifyToken>, Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<TokenResponse>) { ) -> (StatusCode, Json<TokenResponse>) {
let ruth_info = state.db.get_token_info(&args.token).await.unwrap(); let auth_info = state.db.get_token_info(&args.token).await.unwrap();
match auth_info {
Some(auth) => {
let v = auth.into_inner();
match ruth_info { (
Some(v) => ( StatusCode::OK,
StatusCode::OK, Json(TokenResponse::Success(TokenInfo {
Json(TokenResponse::Success(TokenInfo { code: 200,
code: 200, project: v.project,
project: v.project, token: v.token,
token: v.token, device_id: v.device_id,
device_id: v.device_id, disable: v.disable,
disable: v.disable, expire: v.expire,
expire: v.expire, insert_time: v.insert_time,
insert_time: v.insert_time, })),
})), )
), }
None => ( None => (
StatusCode::OK, StatusCode::OK,
Json(TokenResponse::Error(QueryError { Json(TokenResponse::Error(QueryError {
code: 200, code: 404,
msg: "未查询token相关信息".to_owned(), msg: "未查询token相关信息".to_owned(),
})), })),
), ),
} }
} }
async fn disable_token(
State(state): State<AppState>,
Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<VerifyResult>) {
let result = state.db.update_token_state(&args.token, 0).await.unwrap();
if result {
(
StatusCode::OK,
Json(VerifyResult {
code: 200,
msg: "操作成功".to_owned(),
}),
)
} else {
(
StatusCode::OK,
Json(VerifyResult {
code: 404,
msg: "操作失败".to_owned(),
}),
)
}
}
async fn enable_token(
State(state): State<AppState>,
Query(args): Query<VerifyToken>,
) -> (StatusCode, Json<VerifyResult>) {
let result = state.db.update_token_state(&args.token, 1).await.unwrap();
if result {
(
StatusCode::OK,
Json(VerifyResult {
code: 200,
msg: "操作成功".to_owned(),
}),
)
} else {
(
StatusCode::OK,
Json(VerifyResult {
code: 404,
msg: "操作失败".to_owned(),
}),
)
}
}
async fn renewal_token(
State(state): State<AppState>,
Query(args): Query<RenewalToken>,
) -> (StatusCode, Json<TokenResponse>) {
let auth_info = state
.db
.update_token_expiry(&args.token, args.days)
.await
.unwrap();
match auth_info {
Some(auth) => {
let v = auth.into_inner();
(
StatusCode::OK,
Json(TokenResponse::Success(TokenInfo {
code: 200,
project: v.project,
token: v.token,
device_id: v.device_id,
disable: v.disable,
expire: v.expire,
insert_time: v.insert_time,
})),
)
}
None => (
StatusCode::OK,
Json(TokenResponse::Error(QueryError {
code: 404,
msg: "操作失败".to_owned(),
})),
),
}
}
fn get_current_datetime() -> String { fn get_current_datetime() -> String {
Local::now().format("%Y-%m-%d %H:%M:%S").to_string() Local::now().format("%Y-%m-%d %H:%M:%S").to_string()
} }
@@ -224,3 +314,9 @@ struct CreateTokenInfo {
token: String, token: String,
msg: String, msg: String,
} }
#[derive(Deserialize)]
struct RenewalToken {
token: String,
days: i64,
}