212 lines
5.3 KiB
Rust
212 lines
5.3 KiB
Rust
use chrono::{Duration, Local, NaiveDateTime};
|
|
use serde::Serialize;
|
|
use std::{fs, path::Path};
|
|
use welds::{connections::sqlite::SqliteClient, prelude::*};
|
|
|
|
use crate::InsertArgs;
|
|
#[derive(WeldsModel, Clone, Serialize)]
|
|
#[welds(table = "authorize")]
|
|
pub struct Authorize {
|
|
#[welds(primary_key)]
|
|
pub id: i32,
|
|
pub project: String,
|
|
pub token: String,
|
|
pub device_id: String,
|
|
pub disable: i8,
|
|
pub expire: String,
|
|
pub insert_time: String,
|
|
}
|
|
|
|
#[derive(Debug, WeldsModel)]
|
|
pub struct Projects {
|
|
pub project: String,
|
|
}
|
|
|
|
/// 包装类,内部持有 SQLite 连接
|
|
pub struct Db {
|
|
client: SqliteClient,
|
|
}
|
|
|
|
impl Db {
|
|
/// 初始化:建目录 -> 建文件 -> 建表 -> 返回 Self
|
|
pub async fn new(
|
|
db_path: impl AsRef<str>,
|
|
) -> Result<Self, Box<dyn std::error::Error>> {
|
|
let db_path = db_path.as_ref();
|
|
|
|
let parent = Path::new(db_path).parent().unwrap();
|
|
if !parent.exists() {
|
|
fs::create_dir_all(parent)?;
|
|
}
|
|
if !Path::new(db_path).exists() {
|
|
fs::File::create(db_path)?;
|
|
}
|
|
|
|
let conn_str = format!("sqlite://{db_path}");
|
|
let client = welds::connections::sqlite::connect(&conn_str).await?;
|
|
|
|
client
|
|
.execute(
|
|
r#"
|
|
CREATE TABLE IF NOT EXISTS authorize (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
project TEXT NOT NULL,
|
|
token TEXT NOT NULL,
|
|
device_id TEXT NOT NULL,
|
|
disable INTEGER DEFAULT 1,
|
|
expire TEXT NOT NULL,
|
|
insert_time TEXT NOT NULL
|
|
)
|
|
"#,
|
|
&[],
|
|
)
|
|
.await?;
|
|
|
|
Ok(Db { client })
|
|
}
|
|
|
|
/// 通过project和device_id查询是否存在数据库
|
|
pub async fn exists_project(
|
|
&self,
|
|
project: &str,
|
|
device_id: &str,
|
|
) -> Result<Option<Authorize>, Box<dyn std::error::Error>> {
|
|
let row = Authorize::all()
|
|
.where_col(|a| a.project.equal(project))
|
|
.where_col(|a| a.device_id.equal(device_id))
|
|
.limit(1)
|
|
.run(&self.client)
|
|
.await?
|
|
.into_inners()
|
|
.into_iter()
|
|
.next();
|
|
Ok(row)
|
|
}
|
|
|
|
/// 判断 token 是否存在
|
|
pub async fn verify_token(
|
|
&self,
|
|
token: &str,
|
|
) -> Result<bool, Box<dyn std::error::Error>> {
|
|
let row = Authorize::where_col(|p| p.token.equal(token))
|
|
.limit(1)
|
|
.run(&self.client)
|
|
.await?
|
|
.into_iter()
|
|
.next();
|
|
match row {
|
|
Some(a) => {
|
|
// 判断token有效性
|
|
if a.disable == 0 {
|
|
return Ok(false);
|
|
}
|
|
// 判断截止时间
|
|
let expire_time =
|
|
NaiveDateTime::parse_from_str(&a.expire, "%Y-%m-%d %H:%M:%S")?;
|
|
let now_time = Local::now().naive_local();
|
|
Ok(now_time < expire_time)
|
|
}
|
|
// 如果没有找到,直接返回 false
|
|
None => return Ok(false),
|
|
}
|
|
}
|
|
|
|
/// 查询 token 的详细信息,返回 Option<Authorize> 如果存在
|
|
pub async fn get_token_info(
|
|
&self,
|
|
token: &str,
|
|
) -> Result<
|
|
Option<DbState<Authorize>>,
|
|
Box<dyn std::error::Error + Send + Sync>,
|
|
> {
|
|
let row = Authorize::where_col(|p| p.token.equal(token))
|
|
.limit(1)
|
|
.run(&self.client)
|
|
.await?
|
|
.into_iter()
|
|
.next();
|
|
|
|
Ok(row)
|
|
}
|
|
|
|
/// 插入新的授权数据
|
|
pub async fn insert_authorize(
|
|
&self,
|
|
args: InsertArgs,
|
|
) -> Result<i32, Box<dyn std::error::Error>> {
|
|
let mut auth = Authorize::new();
|
|
auth.project = args.project.to_string();
|
|
auth.token = args.token.to_string();
|
|
auth.device_id = args.device_id.to_string();
|
|
auth.disable = args.disable;
|
|
auth.expire = args.expire.to_string();
|
|
auth.insert_time = args.insert_time.to_string();
|
|
|
|
let _created = auth.save(&self.client).await?;
|
|
Ok(auth.id)
|
|
}
|
|
|
|
/// 禁用 Token
|
|
pub async fn update_token_state(
|
|
&self,
|
|
token: &str,
|
|
state: i8,
|
|
) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
|
|
let rows = Authorize::all()
|
|
.where_col(|a| a.token.equal(token))
|
|
.run(&self.client)
|
|
.await?;
|
|
if rows.is_empty() {
|
|
return Ok(false);
|
|
}
|
|
for mut row in rows {
|
|
row.disable = state;
|
|
row.save(&self.client).await?;
|
|
}
|
|
Ok(true)
|
|
}
|
|
|
|
/// 更新 Token 的有效期
|
|
pub async fn update_token_expiry(
|
|
&self,
|
|
token: &str,
|
|
days_to_add: i64,
|
|
) -> Result<
|
|
Option<DbState<Authorize>>,
|
|
Box<dyn std::error::Error + Send + Sync>,
|
|
> {
|
|
let auth_info = self.get_token_info(token).await?;
|
|
match auth_info {
|
|
Some(mut auth) => {
|
|
let current_expiry =
|
|
NaiveDateTime::parse_from_str(&auth.expire, "%Y-%m-%d %H:%M:%S")?;
|
|
let new_expiry = current_expiry + Duration::days(days_to_add);
|
|
|
|
auth.expire = new_expiry.format("%Y-%m-%d %H:%M:%S").to_string();
|
|
auth.save(&self.client).await?;
|
|
|
|
Ok(Some(auth))
|
|
}
|
|
None => Ok(None),
|
|
}
|
|
}
|
|
|
|
pub async fn get_all_project(
|
|
&self,
|
|
) -> Result<Vec<String>, Box<dyn std::error::Error + Send + Sync>> {
|
|
let rows = Authorize::all()
|
|
.select(|a| a.project)
|
|
.group_by(|a| a.project)
|
|
.run(&self.client)
|
|
.await?;
|
|
|
|
// 用Projects结构体来收集查询到的结果
|
|
let result: Vec<Projects> = rows.collect_into()?;
|
|
let mut projects:Vec<String> = Vec::new();
|
|
for row in result {
|
|
projects.push(row.project);
|
|
}
|
|
Ok(projects)
|
|
}
|
|
}
|