Add backend support for users

This commit is contained in:
Sami Abuzakuk
2025-11-01 16:05:34 +01:00
parent 16989ed518
commit 374558d30f
4 changed files with 309 additions and 65 deletions

View File

@@ -1,14 +1,38 @@
from datetime import datetime
from fastapi import FastAPI, Query
from fastapi.exceptions import HTTPException
from fastapi import FastAPI, Depends, HTTPException, status, Query
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from model import Log, SessionLocal, Script, Settings, Subscription, Notification
from model import Log, SessionLocal, Script, Settings, Subscription, Notification, User
from run_scripts import run_scripts, update_requirements, update_environment
import uvicorn
from passlib.context import CryptContext
import os
from model import ensure_default_setting
from auth import (
get_password_hash,
create_access_token,
authenticate_user,
get_current_user,
)
app = FastAPI()
# JWT settings
SECRET_KEY = os.getenv("SECRET_KEY", "")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
if SECRET_KEY == "":
raise ValueError("SECRET_KEY environment variable is not set")
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
ensure_default_setting()
# Update cors
app.add_middleware(
CORSMiddleware,
@@ -19,6 +43,17 @@ app.add_middleware(
)
# User registration/login models
class UserCreate(BaseModel):
username: str
password: str
class Token(BaseModel):
access_token: str
token_type: str
# Define Pydantic models
class ScriptBase(BaseModel):
name: str
@@ -52,6 +87,39 @@ def hello():
return {"message": "Welcome to the Project Monitor API"}
@app.post("/register", response_model=Token)
def register(user: UserCreate):
db = SessionLocal()
existing_user = db.query(User).filter(User.username == user.username).first()
if existing_user:
db.close()
raise HTTPException(status_code=400, detail="Username already registered")
hashed_password = get_password_hash(user.password)
new_user = User(username=user.username, password_hash=hashed_password)
db.add(new_user)
db.commit()
db.refresh(new_user)
access_token = create_access_token(data={"sub": new_user.username})
db.close()
return {"access_token": access_token, "token_type": "bearer"}
@app.post("/login", response_model=Token)
def login(form_data: OAuth2PasswordRequestForm = Depends()):
db = SessionLocal()
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
db.close()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(data={"sub": user.username})
db.close()
return {"access_token": access_token, "token_type": "bearer"}
class SubscriptionCreate(BaseModel):
topic: str
@@ -67,9 +135,11 @@ class SubscriptionResponse(BaseModel):
# Subscriptions API Endpoints
@app.get("/subscriptions", response_model=list[SubscriptionResponse])
def list_subscriptions():
def list_subscriptions(current_user: User = Depends(get_current_user)):
db = SessionLocal()
subscriptions = db.query(Subscription).all()
subscriptions = (
db.query(Subscription).filter(Subscription.user_id == current_user.id).all()
)
# TODO: find a better way to do this
for subscription in subscriptions:
@@ -88,7 +158,9 @@ def list_subscriptions():
@app.get("/subscriptions/{subscription_id}", response_model=SubscriptionResponse)
def get_subscription(subscription_id: int):
def get_subscription(
subscription_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
subscription = (
db.query(Subscription).filter(Subscription.id == subscription_id).first()
@@ -112,7 +184,9 @@ def get_subscription(subscription_id: int):
@app.post("/subscriptions")
def add_subscription(subscription: SubscriptionCreate):
def add_subscription(
subscription: SubscriptionCreate, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
existing_subscription = (
db.query(Subscription).filter(Subscription.topic == subscription.topic).first()
@@ -120,7 +194,7 @@ def add_subscription(subscription: SubscriptionCreate):
if existing_subscription:
db.close()
raise HTTPException(status_code=400, detail="Subscription already exists")
new_subscription = Subscription(topic=subscription.topic)
new_subscription = Subscription(topic=subscription.topic, user_id=current_user.id)
db.add(new_subscription)
db.commit()
db.refresh(new_subscription)
@@ -129,7 +203,9 @@ def add_subscription(subscription: SubscriptionCreate):
@app.delete("/subscriptions/{subscription_id}")
def remove_subscription(subscription_id: int):
def remove_subscription(
subscription_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
subscription = (
db.query(Subscription).filter(Subscription.id == subscription_id).first()
@@ -148,6 +224,7 @@ def list_subscription_notifications(
subscription_id: int,
limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0),
current_user: User = Depends(get_current_user),
):
db = SessionLocal()
notifications = (
@@ -166,7 +243,7 @@ def list_subscription_notifications(
@app.get("/notifications")
def list_notifications():
def list_notifications(current_user: User = Depends(get_current_user)):
db = SessionLocal()
notifications = db.query(Notification).all()
db.close()
@@ -177,7 +254,9 @@ def list_notifications():
@app.delete("/notifications/{notification_id}")
def remove_notification(notification_id: int):
def remove_notification(
notification_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
notification = (
db.query(Notification).filter(Notification.id == notification_id).first()
@@ -215,7 +294,11 @@ class NotificationResponse(NotificationCreate):
@app.put("/notifications/{notification_id}", response_model=NotificationResponse)
def update_notification(notification_id: int, notification: NotificationUpdate):
def update_notification(
notification_id: int,
notification: NotificationUpdate,
current_user: User = Depends(get_current_user),
):
db = SessionLocal()
existing_notification = (
db.query(Notification).filter(Notification.id == notification_id).first()
@@ -240,7 +323,9 @@ def update_notification(notification_id: int, notification: NotificationUpdate):
@app.post("/notifications", response_model=NotificationResponse)
def create_notification(notification: NotificationCreate):
def create_notification(
notification: NotificationCreate, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
new_notification = Notification(
subscription_id=notification.subscription_id,
@@ -259,7 +344,6 @@ def create_notification(notification: NotificationCreate):
class SettingsBase(BaseModel):
requirements: str
environment: str
user: str
ntfy_url: str
@@ -274,18 +358,39 @@ class SettingsResponse(SettingsBase):
# Settings API Endpoints
@app.get("/settings", response_model=list[SettingsResponse])
def read_settings():
@app.get("/settings", response_model=SettingsResponse)
def read_settings(current_user: User = Depends(get_current_user)):
db = SessionLocal()
settings = db.query(Settings).all()
settings = db.query(Settings).filter(Settings.user_id == current_user.id).all()
if not settings:
# Add a default settings row for this user if not found
new_setting = Settings(
requirements="",
environment="",
user_id=current_user.id,
ntfy_url="https://ntfy.abzk.fr",
)
db.add(new_setting)
db.commit()
db.refresh(new_setting)
db.close()
return new_setting
if len(settings) > 1:
raise HTTPException(status_code=400, detail="Multiple settings found")
settings = settings[0]
db.close()
return settings
@app.post("/settings", response_model=SettingsResponse)
def create_setting(settings: SettingsBase):
def create_setting(
settings: SettingsBase, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
new_setting = Settings(**settings.model_dump())
new_setting = Settings(**settings.model_dump(), user_id=current_user.id)
db.add(new_setting)
db.commit()
db.refresh(new_setting)
@@ -295,9 +400,13 @@ def create_setting(settings: SettingsBase):
@app.get("/settings/{settings_id}", response_model=SettingsResponse)
def read_setting(settings_id: int):
def read_setting(settings_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal()
setting = db.query(Settings).filter(Settings.id == settings_id).first()
setting = (
db.query(Settings)
.filter(Settings.id == settings_id, Settings.user_id == current_user.id)
.first()
)
db.close()
if not setting:
raise HTTPException(status_code=404, detail="Setting not found")
@@ -305,9 +414,17 @@ def read_setting(settings_id: int):
@app.put("/settings/{settings_id}", response_model=SettingsResponse)
def update_setting(settings_id: int, settings: SettingsUpdate):
def update_setting(
settings_id: int,
settings: SettingsUpdate,
current_user: User = Depends(get_current_user),
):
db = SessionLocal()
existing_setting = db.query(Settings).filter(Settings.id == settings_id).first()
existing_setting = (
db.query(Settings)
.filter(Settings.id == settings_id, Settings.user_id == current_user.id)
.first()
)
if not existing_setting:
raise HTTPException(status_code=404, detail="Setting not found")
@@ -330,17 +447,19 @@ def update_setting(settings_id: int, settings: SettingsUpdate):
@app.get("/script", response_model=list[ScriptResponse])
def read_scripts():
def read_scripts(current_user: User = Depends(get_current_user)):
db = SessionLocal()
scripts = db.query(Script).all()
scripts = db.query(Script).filter(Script.user_id == current_user.id).all()
db.close()
return scripts
@app.post("/script", response_model=ScriptResponse)
def create_script(script: ScriptCreate):
def create_script(script: ScriptCreate, current_user: User = Depends(get_current_user)):
db = SessionLocal()
new_script = Script(name=script.name, script_content=script.script_content)
new_script = Script(
name=script.name, script_content=script.script_content, user_id=current_user.id
)
db.add(new_script)
db.commit()
db.refresh(new_script)
@@ -349,7 +468,7 @@ def create_script(script: ScriptCreate):
@app.get("/script/{script_id}", response_model=ScriptResponse)
def read_script(script_id: int):
def read_script(script_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal()
script = db.query(Script).filter(Script.id == script_id).first()
db.close()
@@ -359,7 +478,7 @@ def read_script(script_id: int):
@app.delete("/script/{script_id}")
def delete_script(script_id: int):
def delete_script(script_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal()
script = db.query(Script).filter(Script.id == script_id).first()
if not script:
@@ -375,7 +494,9 @@ def delete_script(script_id: int):
@app.put("/script/{script_id}", response_model=ScriptResponse)
def update_script(script_id: int, script: ScriptUpdate):
def update_script(
script_id: int, script: ScriptUpdate, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
existing_script = db.query(Script).filter(Script.id == script_id).first()
if not existing_script:
@@ -391,7 +512,7 @@ def update_script(script_id: int, script: ScriptUpdate):
@app.get("/script/{script_id}/log")
def get_script_logs(script_id: int):
def get_script_logs(script_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal()
logs = db.query(Log).filter(Log.script_id == script_id).all()
db.close()
@@ -399,7 +520,9 @@ def get_script_logs(script_id: int):
@app.post("/script/{script_id}/log")
def create_script_log(script_id: int, log: ScriptLogCreate):
def create_script_log(
script_id: int, log: ScriptLogCreate, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
new_log = Log(
script_id=script_id,
@@ -415,7 +538,9 @@ def create_script_log(script_id: int, log: ScriptLogCreate):
@app.delete("/script/{script_id}/log/{log_id}")
def delete_script_log(script_id: int, log_id: int):
def delete_script_log(
script_id: int, log_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal()
log = db.query(Log).filter(Log.id == log_id and Log.script_id == script_id).first()
if not log:
@@ -427,7 +552,7 @@ def delete_script_log(script_id: int, log_id: int):
@app.post("/script/{script_id}/execute")
def execute_script(script_id: int):
def execute_script(script_id: int, current_user: User = Depends(get_current_user)):
run_scripts([script_id])
return {"run_script": True}