From 374558d30f1788c4b3421cfd85cf614847b5cbaa Mon Sep 17 00:00:00 2001 From: Sami Abuzakuk Date: Sat, 1 Nov 2025 16:05:34 +0100 Subject: [PATCH] Add backend support for users --- backend/auth.py | 67 ++++++++++++ backend/backend.py | 193 +++++++++++++++++++++++++++++------ backend/get_notifications.py | 50 +++++---- backend/model.py | 64 ++++++++++-- 4 files changed, 309 insertions(+), 65 deletions(-) create mode 100644 backend/auth.py diff --git a/backend/auth.py b/backend/auth.py new file mode 100644 index 0000000..945b12b --- /dev/null +++ b/backend/auth.py @@ -0,0 +1,67 @@ +from datetime import datetime, timedelta +from passlib.context import CryptContext +from jose import JWTError, jwt +import os +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from model import User, SessionLocal + +# JWT settings +SECRET_KEY = os.getenv("SECRET_KEY", "") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 60 + +if SECRET_KEY == "": + raise ValueError("SECRET_KEY environment variable is not set") + +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") + + +def verify_password(plain_password, hashed_password): + return pwd_context.verify(plain_password, hashed_password) + + +def get_password_hash(password): + return pwd_context.hash(password) + + +def create_access_token(data: dict, expires_delta: timedelta | None = None): + to_encode = data.copy() + expire = datetime.utcnow() + ( + expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + ) + to_encode.update({"exp": expire}) + return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + + +def get_user(db, username: str): + return db.query(User).filter(User.username == username).first() + + +def authenticate_user(db, username: str, password: str): + user = get_user(db, username) + if not user or not verify_password(password, user.password_hash): + return None + return user + + +def get_current_user(token: str = Depends(oauth2_scheme)): + db = SessionLocal() + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + username: str | None = payload.get("sub") + if username is None: + raise credentials_exception + except JWTError: + raise credentials_exception + user = get_user(db, username) + db.close() + if user is None: + raise credentials_exception + return user diff --git a/backend/backend.py b/backend/backend.py index b1fbf66..2053498 100644 --- a/backend/backend.py +++ b/backend/backend.py @@ -1,14 +1,38 @@ from datetime import datetime -from fastapi import FastAPI, Query -from fastapi.exceptions import HTTPException +from fastapi import FastAPI, Depends, HTTPException, status, Query +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from model import Log, SessionLocal, Script, Settings, Subscription, Notification +from model import Log, SessionLocal, Script, Settings, Subscription, Notification, User from run_scripts import run_scripts, update_requirements, update_environment import uvicorn +from passlib.context import CryptContext +import os +from model import ensure_default_setting + +from auth import ( + get_password_hash, + create_access_token, + authenticate_user, + get_current_user, +) app = FastAPI() +# JWT settings +SECRET_KEY = os.getenv("SECRET_KEY", "") +ALGORITHM = "HS256" +ACCESS_TOKEN_EXPIRE_MINUTES = 60 + +if SECRET_KEY == "": + raise ValueError("SECRET_KEY environment variable is not set") + +pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login") + +ensure_default_setting() + + # Update cors app.add_middleware( CORSMiddleware, @@ -19,6 +43,17 @@ app.add_middleware( ) +# User registration/login models +class UserCreate(BaseModel): + username: str + password: str + + +class Token(BaseModel): + access_token: str + token_type: str + + # Define Pydantic models class ScriptBase(BaseModel): name: str @@ -52,6 +87,39 @@ def hello(): return {"message": "Welcome to the Project Monitor API"} +@app.post("/register", response_model=Token) +def register(user: UserCreate): + db = SessionLocal() + existing_user = db.query(User).filter(User.username == user.username).first() + if existing_user: + db.close() + raise HTTPException(status_code=400, detail="Username already registered") + hashed_password = get_password_hash(user.password) + new_user = User(username=user.username, password_hash=hashed_password) + db.add(new_user) + db.commit() + db.refresh(new_user) + access_token = create_access_token(data={"sub": new_user.username}) + db.close() + return {"access_token": access_token, "token_type": "bearer"} + + +@app.post("/login", response_model=Token) +def login(form_data: OAuth2PasswordRequestForm = Depends()): + db = SessionLocal() + user = authenticate_user(db, form_data.username, form_data.password) + if not user: + db.close() + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token = create_access_token(data={"sub": user.username}) + db.close() + return {"access_token": access_token, "token_type": "bearer"} + + class SubscriptionCreate(BaseModel): topic: str @@ -67,9 +135,11 @@ class SubscriptionResponse(BaseModel): # Subscriptions API Endpoints @app.get("/subscriptions", response_model=list[SubscriptionResponse]) -def list_subscriptions(): +def list_subscriptions(current_user: User = Depends(get_current_user)): db = SessionLocal() - subscriptions = db.query(Subscription).all() + subscriptions = ( + db.query(Subscription).filter(Subscription.user_id == current_user.id).all() + ) # TODO: find a better way to do this for subscription in subscriptions: @@ -88,7 +158,9 @@ def list_subscriptions(): @app.get("/subscriptions/{subscription_id}", response_model=SubscriptionResponse) -def get_subscription(subscription_id: int): +def get_subscription( + subscription_id: int, current_user: User = Depends(get_current_user) +): db = SessionLocal() subscription = ( db.query(Subscription).filter(Subscription.id == subscription_id).first() @@ -112,7 +184,9 @@ def get_subscription(subscription_id: int): @app.post("/subscriptions") -def add_subscription(subscription: SubscriptionCreate): +def add_subscription( + subscription: SubscriptionCreate, current_user: User = Depends(get_current_user) +): db = SessionLocal() existing_subscription = ( db.query(Subscription).filter(Subscription.topic == subscription.topic).first() @@ -120,7 +194,7 @@ def add_subscription(subscription: SubscriptionCreate): if existing_subscription: db.close() raise HTTPException(status_code=400, detail="Subscription already exists") - new_subscription = Subscription(topic=subscription.topic) + new_subscription = Subscription(topic=subscription.topic, user_id=current_user.id) db.add(new_subscription) db.commit() db.refresh(new_subscription) @@ -129,7 +203,9 @@ def add_subscription(subscription: SubscriptionCreate): @app.delete("/subscriptions/{subscription_id}") -def remove_subscription(subscription_id: int): +def remove_subscription( + subscription_id: int, current_user: User = Depends(get_current_user) +): db = SessionLocal() subscription = ( db.query(Subscription).filter(Subscription.id == subscription_id).first() @@ -148,6 +224,7 @@ def list_subscription_notifications( subscription_id: int, limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0), + current_user: User = Depends(get_current_user), ): db = SessionLocal() notifications = ( @@ -166,7 +243,7 @@ def list_subscription_notifications( @app.get("/notifications") -def list_notifications(): +def list_notifications(current_user: User = Depends(get_current_user)): db = SessionLocal() notifications = db.query(Notification).all() db.close() @@ -177,7 +254,9 @@ def list_notifications(): @app.delete("/notifications/{notification_id}") -def remove_notification(notification_id: int): +def remove_notification( + notification_id: int, current_user: User = Depends(get_current_user) +): db = SessionLocal() notification = ( db.query(Notification).filter(Notification.id == notification_id).first() @@ -215,7 +294,11 @@ class NotificationResponse(NotificationCreate): @app.put("/notifications/{notification_id}", response_model=NotificationResponse) -def update_notification(notification_id: int, notification: NotificationUpdate): +def update_notification( + notification_id: int, + notification: NotificationUpdate, + current_user: User = Depends(get_current_user), +): db = SessionLocal() existing_notification = ( db.query(Notification).filter(Notification.id == notification_id).first() @@ -240,7 +323,9 @@ def update_notification(notification_id: int, notification: NotificationUpdate): @app.post("/notifications", response_model=NotificationResponse) -def create_notification(notification: NotificationCreate): +def create_notification( + notification: NotificationCreate, current_user: User = Depends(get_current_user) +): db = SessionLocal() new_notification = Notification( subscription_id=notification.subscription_id, @@ -259,7 +344,6 @@ def create_notification(notification: NotificationCreate): class SettingsBase(BaseModel): requirements: str environment: str - user: str ntfy_url: str @@ -274,18 +358,39 @@ class SettingsResponse(SettingsBase): # Settings API Endpoints -@app.get("/settings", response_model=list[SettingsResponse]) -def read_settings(): +@app.get("/settings", response_model=SettingsResponse) +def read_settings(current_user: User = Depends(get_current_user)): db = SessionLocal() - settings = db.query(Settings).all() + settings = db.query(Settings).filter(Settings.user_id == current_user.id).all() + if not settings: + # Add a default settings row for this user if not found + new_setting = Settings( + requirements="", + environment="", + user_id=current_user.id, + ntfy_url="https://ntfy.abzk.fr", + ) + db.add(new_setting) + db.commit() + db.refresh(new_setting) + db.close() + return new_setting + + if len(settings) > 1: + raise HTTPException(status_code=400, detail="Multiple settings found") + + settings = settings[0] + db.close() return settings @app.post("/settings", response_model=SettingsResponse) -def create_setting(settings: SettingsBase): +def create_setting( + settings: SettingsBase, current_user: User = Depends(get_current_user) +): db = SessionLocal() - new_setting = Settings(**settings.model_dump()) + new_setting = Settings(**settings.model_dump(), user_id=current_user.id) db.add(new_setting) db.commit() db.refresh(new_setting) @@ -295,9 +400,13 @@ def create_setting(settings: SettingsBase): @app.get("/settings/{settings_id}", response_model=SettingsResponse) -def read_setting(settings_id: int): +def read_setting(settings_id: int, current_user: User = Depends(get_current_user)): db = SessionLocal() - setting = db.query(Settings).filter(Settings.id == settings_id).first() + setting = ( + db.query(Settings) + .filter(Settings.id == settings_id, Settings.user_id == current_user.id) + .first() + ) db.close() if not setting: raise HTTPException(status_code=404, detail="Setting not found") @@ -305,9 +414,17 @@ def read_setting(settings_id: int): @app.put("/settings/{settings_id}", response_model=SettingsResponse) -def update_setting(settings_id: int, settings: SettingsUpdate): +def update_setting( + settings_id: int, + settings: SettingsUpdate, + current_user: User = Depends(get_current_user), +): db = SessionLocal() - existing_setting = db.query(Settings).filter(Settings.id == settings_id).first() + existing_setting = ( + db.query(Settings) + .filter(Settings.id == settings_id, Settings.user_id == current_user.id) + .first() + ) if not existing_setting: raise HTTPException(status_code=404, detail="Setting not found") @@ -330,17 +447,19 @@ def update_setting(settings_id: int, settings: SettingsUpdate): @app.get("/script", response_model=list[ScriptResponse]) -def read_scripts(): +def read_scripts(current_user: User = Depends(get_current_user)): db = SessionLocal() - scripts = db.query(Script).all() + scripts = db.query(Script).filter(Script.user_id == current_user.id).all() db.close() return scripts @app.post("/script", response_model=ScriptResponse) -def create_script(script: ScriptCreate): +def create_script(script: ScriptCreate, current_user: User = Depends(get_current_user)): db = SessionLocal() - new_script = Script(name=script.name, script_content=script.script_content) + new_script = Script( + name=script.name, script_content=script.script_content, user_id=current_user.id + ) db.add(new_script) db.commit() db.refresh(new_script) @@ -349,7 +468,7 @@ def create_script(script: ScriptCreate): @app.get("/script/{script_id}", response_model=ScriptResponse) -def read_script(script_id: int): +def read_script(script_id: int, current_user: User = Depends(get_current_user)): db = SessionLocal() script = db.query(Script).filter(Script.id == script_id).first() db.close() @@ -359,7 +478,7 @@ def read_script(script_id: int): @app.delete("/script/{script_id}") -def delete_script(script_id: int): +def delete_script(script_id: int, current_user: User = Depends(get_current_user)): db = SessionLocal() script = db.query(Script).filter(Script.id == script_id).first() if not script: @@ -375,7 +494,9 @@ def delete_script(script_id: int): @app.put("/script/{script_id}", response_model=ScriptResponse) -def update_script(script_id: int, script: ScriptUpdate): +def update_script( + script_id: int, script: ScriptUpdate, current_user: User = Depends(get_current_user) +): db = SessionLocal() existing_script = db.query(Script).filter(Script.id == script_id).first() if not existing_script: @@ -391,7 +512,7 @@ def update_script(script_id: int, script: ScriptUpdate): @app.get("/script/{script_id}/log") -def get_script_logs(script_id: int): +def get_script_logs(script_id: int, current_user: User = Depends(get_current_user)): db = SessionLocal() logs = db.query(Log).filter(Log.script_id == script_id).all() db.close() @@ -399,7 +520,9 @@ def get_script_logs(script_id: int): @app.post("/script/{script_id}/log") -def create_script_log(script_id: int, log: ScriptLogCreate): +def create_script_log( + script_id: int, log: ScriptLogCreate, current_user: User = Depends(get_current_user) +): db = SessionLocal() new_log = Log( script_id=script_id, @@ -415,7 +538,9 @@ def create_script_log(script_id: int, log: ScriptLogCreate): @app.delete("/script/{script_id}/log/{log_id}") -def delete_script_log(script_id: int, log_id: int): +def delete_script_log( + script_id: int, log_id: int, current_user: User = Depends(get_current_user) +): db = SessionLocal() log = db.query(Log).filter(Log.id == log_id and Log.script_id == script_id).first() if not log: @@ -427,7 +552,7 @@ def delete_script_log(script_id: int, log_id: int): @app.post("/script/{script_id}/execute") -def execute_script(script_id: int): +def execute_script(script_id: int, current_user: User = Depends(get_current_user)): run_scripts([script_id]) return {"run_script": True} diff --git a/backend/get_notifications.py b/backend/get_notifications.py index 271731a..5caa6bc 100644 --- a/backend/get_notifications.py +++ b/backend/get_notifications.py @@ -5,7 +5,6 @@ from model import SessionLocal, Subscription, Settings, Notification import json # Constants - NTFY_TOKEN = os.getenv("NTFY_TOKEN") @@ -34,14 +33,11 @@ def fetch_ntfy_notifications(base_url, subscriptions): notifications.append(notification) print(f"Fetched {len(notifications)} notifications") - print(notifications) - return notifications def save_notifications_to_db(notifications, topic_to_subscription, db): """Save the fetched notifications to the database and update last_message_id.""" - db = SessionLocal() last_message_ids = {} for notification in notifications: topic = notification["topic"] @@ -67,33 +63,26 @@ def save_notifications_to_db(notifications, topic_to_subscription, db): if subscription: subscription.last_message_id = message_id db.commit() - db.close() -def main(): - """Main function to fetch and save notifications.""" - db = SessionLocal() - - # Get the ntfy base URL from settings - settings = db.query(Settings).filter(Settings.user == "default").first() - if not settings: - print("Default user settings not found.") - return - - ntfy_url = settings.ntfy_url +def process_user_notifications(user_settings, db): + """Process notifications for a specific user's subscriptions.""" + ntfy_url = user_settings.ntfy_url if not ntfy_url: - print("Ntfy URL not found in settings.") + print(f"Ntfy URL not found for user ID {user_settings.user_id}. Skipping...") return - # Get all subscribed topics - subscriptions = db.query(Subscription).all() + # Get all subscriptions for the user + subscriptions = ( + db.query(Subscription) + .filter(Subscription.user_id == user_settings.user_id) + .all() + ) topic_to_subscription = { subscription.topic: subscription.id for subscription in subscriptions } - db.close() - # Fetch notifications from ntfy.sh notifications = fetch_ntfy_notifications(ntfy_url, subscriptions) @@ -101,5 +90,24 @@ def main(): save_notifications_to_db(notifications, topic_to_subscription, db) +def main(): + """Main function to fetch and save notifications for all users.""" + db = SessionLocal() + + # Get all user settings + user_settings_list = db.query(Settings).all() + + if not user_settings_list: + print("No user settings found.") + return + + # Process notifications for each user + for user_settings in user_settings_list: + print(f"Processing notifications for user ID {user_settings.user_id}") + process_user_notifications(user_settings, db) + + db.close() + + if __name__ == "__main__": main() diff --git a/backend/model.py b/backend/model.py index 38384f3..62beabd 100644 --- a/backend/model.py +++ b/backend/model.py @@ -1,9 +1,12 @@ from sqlalchemy import create_engine, Column, Integer, String, Text, ForeignKey, Boolean +from sqlalchemy.sql.sqltypes import DateTime from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import sessionmaker from sqlalchemy.sql.functions import func from sqlalchemy.sql.sqltypes import DateTime import os +import secrets +from passlib.context import CryptContext # Initialize the database DATABASE_URL = os.getenv("DATABASE_URL") @@ -17,7 +20,15 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) Base = declarative_base() -# Define the table model +class User(Base): + __tablename__ = "users" + + id = Column(Integer, primary_key=True, index=True) + username = Column(String(64), unique=True, nullable=False, index=True) + password_hash = Column(String(128), nullable=False) + created_at = Column( + DateTime(timezone=True), nullable=False, server_default=func.now() + ) class Script(Base): @@ -30,6 +41,9 @@ class Script(Base): created_at = Column( DateTime(timezone=True), nullable=False, server_default=func.now() ) + user_id = Column( + Integer, ForeignKey("users.id", name="fk_script_user_id"), nullable=False + ) class Log(Base): @@ -43,7 +57,9 @@ class Log(Base): DateTime(timezone=True), nullable=False, server_default=func.now() ) - script_id = Column(Integer, ForeignKey("scripts.id"), nullable=False) + script_id = Column( + Integer, ForeignKey("scripts.id", name="fk_log_script_id"), nullable=False + ) class Settings(Base): @@ -52,8 +68,10 @@ class Settings(Base): id = Column(Integer, primary_key=True, index=True) requirements = Column(String, nullable=False) environment = Column(String, nullable=False) - user = Column(String, nullable=False) ntfy_url = Column(String, nullable=True) + user_id = Column( + Integer, ForeignKey("users.id", name="fk_user_settings_user_id"), nullable=False + ) class Subscription(Base): @@ -65,6 +83,9 @@ class Subscription(Base): created_at = Column( DateTime(timezone=True), nullable=False, server_default=func.now() ) + user_id = Column( + Integer, ForeignKey("users.id", name="fk_subscription_user_id"), nullable=False + ) class Notification(Base): @@ -77,7 +98,11 @@ class Notification(Base): viewed = Column(Boolean, default=False) sent = Column(Boolean, default=False) - subscription_id = Column(Integer, ForeignKey("subscriptions.id"), nullable=False) + subscription_id = Column( + Integer, + ForeignKey("subscriptions.id", name="fk_notification_subscription_id"), + nullable=False, + ) created_at = Column( DateTime(timezone=True), nullable=False, server_default=func.now() ) @@ -87,20 +112,39 @@ class Notification(Base): Base.metadata.create_all(bind=engine) -# Ensure a default setting line exists +# Ensure a default admin user exists def ensure_default_setting(): db = SessionLocal() - default_setting = db.query(Settings).filter(Settings.user == "default").first() + admin_user = db.query(User).filter(User.username == "admin").first() + if not admin_user: + pwd_context = CryptContext(schemes=["argon2"], deprecated="auto") + random_password = secrets.token_urlsafe(12) + password_hash = pwd_context.hash(random_password) + admin_user = User(username="admin", password_hash=password_hash) + db.add(admin_user) + db.commit() + print( + f"Default admin user created. Username: admin, Password: {random_password}" + ) + # Refresh to get admin_user.id + db.refresh(admin_user) + # Set all rows with null user_id in Script and Subscription to admin user id + db.query(Script).filter(Script.user_id is None).update({"user_id": admin_user.id}) + db.query(Subscription).filter(Subscription.user_id is None).update( + {"user_id": admin_user.id} + ) + db.commit() + + default_setting = ( + db.query(Settings).filter(Settings.user_id == admin_user.id).first() + ) if not default_setting: new_setting = Settings( requirements="", environment="", - user="default", + user_id=admin_user.id, ntfy_url="https://ntfy.abzk.fr", ) db.add(new_setting) db.commit() db.close() - - -ensure_default_setting()