Add backend support for users

This commit is contained in:
Sami Abuzakuk
2025-11-01 16:05:34 +01:00
parent 16989ed518
commit 374558d30f
4 changed files with 309 additions and 65 deletions

67
backend/auth.py Normal file
View File

@@ -0,0 +1,67 @@
from datetime import datetime, timedelta
from passlib.context import CryptContext
from jose import JWTError, jwt
import os
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from model import User, SessionLocal
# JWT settings
SECRET_KEY = os.getenv("SECRET_KEY", "")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
if SECRET_KEY == "":
raise ValueError("SECRET_KEY environment variable is not set")
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
def verify_password(plain_password, hashed_password):
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password):
return pwd_context.hash(password)
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
expire = datetime.utcnow() + (
expires_delta or timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
)
to_encode.update({"exp": expire})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
def get_user(db, username: str):
return db.query(User).filter(User.username == username).first()
def authenticate_user(db, username: str, password: str):
user = get_user(db, username)
if not user or not verify_password(password, user.password_hash):
return None
return user
def get_current_user(token: str = Depends(oauth2_scheme)):
db = SessionLocal()
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
username: str | None = payload.get("sub")
if username is None:
raise credentials_exception
except JWTError:
raise credentials_exception
user = get_user(db, username)
db.close()
if user is None:
raise credentials_exception
return user

View File

@@ -1,14 +1,38 @@
from datetime import datetime from datetime import datetime
from fastapi import FastAPI, Query from fastapi import FastAPI, Depends, HTTPException, status, Query
from fastapi.exceptions import HTTPException from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel from pydantic import BaseModel
from model import Log, SessionLocal, Script, Settings, Subscription, Notification from model import Log, SessionLocal, Script, Settings, Subscription, Notification, User
from run_scripts import run_scripts, update_requirements, update_environment from run_scripts import run_scripts, update_requirements, update_environment
import uvicorn import uvicorn
from passlib.context import CryptContext
import os
from model import ensure_default_setting
from auth import (
get_password_hash,
create_access_token,
authenticate_user,
get_current_user,
)
app = FastAPI() app = FastAPI()
# JWT settings
SECRET_KEY = os.getenv("SECRET_KEY", "")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60
if SECRET_KEY == "":
raise ValueError("SECRET_KEY environment variable is not set")
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="login")
ensure_default_setting()
# Update cors # Update cors
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@@ -19,6 +43,17 @@ app.add_middleware(
) )
# User registration/login models
class UserCreate(BaseModel):
username: str
password: str
class Token(BaseModel):
access_token: str
token_type: str
# Define Pydantic models # Define Pydantic models
class ScriptBase(BaseModel): class ScriptBase(BaseModel):
name: str name: str
@@ -52,6 +87,39 @@ def hello():
return {"message": "Welcome to the Project Monitor API"} return {"message": "Welcome to the Project Monitor API"}
@app.post("/register", response_model=Token)
def register(user: UserCreate):
db = SessionLocal()
existing_user = db.query(User).filter(User.username == user.username).first()
if existing_user:
db.close()
raise HTTPException(status_code=400, detail="Username already registered")
hashed_password = get_password_hash(user.password)
new_user = User(username=user.username, password_hash=hashed_password)
db.add(new_user)
db.commit()
db.refresh(new_user)
access_token = create_access_token(data={"sub": new_user.username})
db.close()
return {"access_token": access_token, "token_type": "bearer"}
@app.post("/login", response_model=Token)
def login(form_data: OAuth2PasswordRequestForm = Depends()):
db = SessionLocal()
user = authenticate_user(db, form_data.username, form_data.password)
if not user:
db.close()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token = create_access_token(data={"sub": user.username})
db.close()
return {"access_token": access_token, "token_type": "bearer"}
class SubscriptionCreate(BaseModel): class SubscriptionCreate(BaseModel):
topic: str topic: str
@@ -67,9 +135,11 @@ class SubscriptionResponse(BaseModel):
# Subscriptions API Endpoints # Subscriptions API Endpoints
@app.get("/subscriptions", response_model=list[SubscriptionResponse]) @app.get("/subscriptions", response_model=list[SubscriptionResponse])
def list_subscriptions(): def list_subscriptions(current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
subscriptions = db.query(Subscription).all() subscriptions = (
db.query(Subscription).filter(Subscription.user_id == current_user.id).all()
)
# TODO: find a better way to do this # TODO: find a better way to do this
for subscription in subscriptions: for subscription in subscriptions:
@@ -88,7 +158,9 @@ def list_subscriptions():
@app.get("/subscriptions/{subscription_id}", response_model=SubscriptionResponse) @app.get("/subscriptions/{subscription_id}", response_model=SubscriptionResponse)
def get_subscription(subscription_id: int): def get_subscription(
subscription_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
subscription = ( subscription = (
db.query(Subscription).filter(Subscription.id == subscription_id).first() db.query(Subscription).filter(Subscription.id == subscription_id).first()
@@ -112,7 +184,9 @@ def get_subscription(subscription_id: int):
@app.post("/subscriptions") @app.post("/subscriptions")
def add_subscription(subscription: SubscriptionCreate): def add_subscription(
subscription: SubscriptionCreate, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
existing_subscription = ( existing_subscription = (
db.query(Subscription).filter(Subscription.topic == subscription.topic).first() db.query(Subscription).filter(Subscription.topic == subscription.topic).first()
@@ -120,7 +194,7 @@ def add_subscription(subscription: SubscriptionCreate):
if existing_subscription: if existing_subscription:
db.close() db.close()
raise HTTPException(status_code=400, detail="Subscription already exists") raise HTTPException(status_code=400, detail="Subscription already exists")
new_subscription = Subscription(topic=subscription.topic) new_subscription = Subscription(topic=subscription.topic, user_id=current_user.id)
db.add(new_subscription) db.add(new_subscription)
db.commit() db.commit()
db.refresh(new_subscription) db.refresh(new_subscription)
@@ -129,7 +203,9 @@ def add_subscription(subscription: SubscriptionCreate):
@app.delete("/subscriptions/{subscription_id}") @app.delete("/subscriptions/{subscription_id}")
def remove_subscription(subscription_id: int): def remove_subscription(
subscription_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
subscription = ( subscription = (
db.query(Subscription).filter(Subscription.id == subscription_id).first() db.query(Subscription).filter(Subscription.id == subscription_id).first()
@@ -148,6 +224,7 @@ def list_subscription_notifications(
subscription_id: int, subscription_id: int,
limit: int = Query(20, ge=1, le=100), limit: int = Query(20, ge=1, le=100),
offset: int = Query(0, ge=0), offset: int = Query(0, ge=0),
current_user: User = Depends(get_current_user),
): ):
db = SessionLocal() db = SessionLocal()
notifications = ( notifications = (
@@ -166,7 +243,7 @@ def list_subscription_notifications(
@app.get("/notifications") @app.get("/notifications")
def list_notifications(): def list_notifications(current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
notifications = db.query(Notification).all() notifications = db.query(Notification).all()
db.close() db.close()
@@ -177,7 +254,9 @@ def list_notifications():
@app.delete("/notifications/{notification_id}") @app.delete("/notifications/{notification_id}")
def remove_notification(notification_id: int): def remove_notification(
notification_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
notification = ( notification = (
db.query(Notification).filter(Notification.id == notification_id).first() db.query(Notification).filter(Notification.id == notification_id).first()
@@ -215,7 +294,11 @@ class NotificationResponse(NotificationCreate):
@app.put("/notifications/{notification_id}", response_model=NotificationResponse) @app.put("/notifications/{notification_id}", response_model=NotificationResponse)
def update_notification(notification_id: int, notification: NotificationUpdate): def update_notification(
notification_id: int,
notification: NotificationUpdate,
current_user: User = Depends(get_current_user),
):
db = SessionLocal() db = SessionLocal()
existing_notification = ( existing_notification = (
db.query(Notification).filter(Notification.id == notification_id).first() db.query(Notification).filter(Notification.id == notification_id).first()
@@ -240,7 +323,9 @@ def update_notification(notification_id: int, notification: NotificationUpdate):
@app.post("/notifications", response_model=NotificationResponse) @app.post("/notifications", response_model=NotificationResponse)
def create_notification(notification: NotificationCreate): def create_notification(
notification: NotificationCreate, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
new_notification = Notification( new_notification = Notification(
subscription_id=notification.subscription_id, subscription_id=notification.subscription_id,
@@ -259,7 +344,6 @@ def create_notification(notification: NotificationCreate):
class SettingsBase(BaseModel): class SettingsBase(BaseModel):
requirements: str requirements: str
environment: str environment: str
user: str
ntfy_url: str ntfy_url: str
@@ -274,18 +358,39 @@ class SettingsResponse(SettingsBase):
# Settings API Endpoints # Settings API Endpoints
@app.get("/settings", response_model=list[SettingsResponse]) @app.get("/settings", response_model=SettingsResponse)
def read_settings(): def read_settings(current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
settings = db.query(Settings).all() settings = db.query(Settings).filter(Settings.user_id == current_user.id).all()
if not settings:
# Add a default settings row for this user if not found
new_setting = Settings(
requirements="",
environment="",
user_id=current_user.id,
ntfy_url="https://ntfy.abzk.fr",
)
db.add(new_setting)
db.commit()
db.refresh(new_setting)
db.close()
return new_setting
if len(settings) > 1:
raise HTTPException(status_code=400, detail="Multiple settings found")
settings = settings[0]
db.close() db.close()
return settings return settings
@app.post("/settings", response_model=SettingsResponse) @app.post("/settings", response_model=SettingsResponse)
def create_setting(settings: SettingsBase): def create_setting(
settings: SettingsBase, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
new_setting = Settings(**settings.model_dump()) new_setting = Settings(**settings.model_dump(), user_id=current_user.id)
db.add(new_setting) db.add(new_setting)
db.commit() db.commit()
db.refresh(new_setting) db.refresh(new_setting)
@@ -295,9 +400,13 @@ def create_setting(settings: SettingsBase):
@app.get("/settings/{settings_id}", response_model=SettingsResponse) @app.get("/settings/{settings_id}", response_model=SettingsResponse)
def read_setting(settings_id: int): def read_setting(settings_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
setting = db.query(Settings).filter(Settings.id == settings_id).first() setting = (
db.query(Settings)
.filter(Settings.id == settings_id, Settings.user_id == current_user.id)
.first()
)
db.close() db.close()
if not setting: if not setting:
raise HTTPException(status_code=404, detail="Setting not found") raise HTTPException(status_code=404, detail="Setting not found")
@@ -305,9 +414,17 @@ def read_setting(settings_id: int):
@app.put("/settings/{settings_id}", response_model=SettingsResponse) @app.put("/settings/{settings_id}", response_model=SettingsResponse)
def update_setting(settings_id: int, settings: SettingsUpdate): def update_setting(
settings_id: int,
settings: SettingsUpdate,
current_user: User = Depends(get_current_user),
):
db = SessionLocal() db = SessionLocal()
existing_setting = db.query(Settings).filter(Settings.id == settings_id).first() existing_setting = (
db.query(Settings)
.filter(Settings.id == settings_id, Settings.user_id == current_user.id)
.first()
)
if not existing_setting: if not existing_setting:
raise HTTPException(status_code=404, detail="Setting not found") raise HTTPException(status_code=404, detail="Setting not found")
@@ -330,17 +447,19 @@ def update_setting(settings_id: int, settings: SettingsUpdate):
@app.get("/script", response_model=list[ScriptResponse]) @app.get("/script", response_model=list[ScriptResponse])
def read_scripts(): def read_scripts(current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
scripts = db.query(Script).all() scripts = db.query(Script).filter(Script.user_id == current_user.id).all()
db.close() db.close()
return scripts return scripts
@app.post("/script", response_model=ScriptResponse) @app.post("/script", response_model=ScriptResponse)
def create_script(script: ScriptCreate): def create_script(script: ScriptCreate, current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
new_script = Script(name=script.name, script_content=script.script_content) new_script = Script(
name=script.name, script_content=script.script_content, user_id=current_user.id
)
db.add(new_script) db.add(new_script)
db.commit() db.commit()
db.refresh(new_script) db.refresh(new_script)
@@ -349,7 +468,7 @@ def create_script(script: ScriptCreate):
@app.get("/script/{script_id}", response_model=ScriptResponse) @app.get("/script/{script_id}", response_model=ScriptResponse)
def read_script(script_id: int): def read_script(script_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
script = db.query(Script).filter(Script.id == script_id).first() script = db.query(Script).filter(Script.id == script_id).first()
db.close() db.close()
@@ -359,7 +478,7 @@ def read_script(script_id: int):
@app.delete("/script/{script_id}") @app.delete("/script/{script_id}")
def delete_script(script_id: int): def delete_script(script_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
script = db.query(Script).filter(Script.id == script_id).first() script = db.query(Script).filter(Script.id == script_id).first()
if not script: if not script:
@@ -375,7 +494,9 @@ def delete_script(script_id: int):
@app.put("/script/{script_id}", response_model=ScriptResponse) @app.put("/script/{script_id}", response_model=ScriptResponse)
def update_script(script_id: int, script: ScriptUpdate): def update_script(
script_id: int, script: ScriptUpdate, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
existing_script = db.query(Script).filter(Script.id == script_id).first() existing_script = db.query(Script).filter(Script.id == script_id).first()
if not existing_script: if not existing_script:
@@ -391,7 +512,7 @@ def update_script(script_id: int, script: ScriptUpdate):
@app.get("/script/{script_id}/log") @app.get("/script/{script_id}/log")
def get_script_logs(script_id: int): def get_script_logs(script_id: int, current_user: User = Depends(get_current_user)):
db = SessionLocal() db = SessionLocal()
logs = db.query(Log).filter(Log.script_id == script_id).all() logs = db.query(Log).filter(Log.script_id == script_id).all()
db.close() db.close()
@@ -399,7 +520,9 @@ def get_script_logs(script_id: int):
@app.post("/script/{script_id}/log") @app.post("/script/{script_id}/log")
def create_script_log(script_id: int, log: ScriptLogCreate): def create_script_log(
script_id: int, log: ScriptLogCreate, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
new_log = Log( new_log = Log(
script_id=script_id, script_id=script_id,
@@ -415,7 +538,9 @@ def create_script_log(script_id: int, log: ScriptLogCreate):
@app.delete("/script/{script_id}/log/{log_id}") @app.delete("/script/{script_id}/log/{log_id}")
def delete_script_log(script_id: int, log_id: int): def delete_script_log(
script_id: int, log_id: int, current_user: User = Depends(get_current_user)
):
db = SessionLocal() db = SessionLocal()
log = db.query(Log).filter(Log.id == log_id and Log.script_id == script_id).first() log = db.query(Log).filter(Log.id == log_id and Log.script_id == script_id).first()
if not log: if not log:
@@ -427,7 +552,7 @@ def delete_script_log(script_id: int, log_id: int):
@app.post("/script/{script_id}/execute") @app.post("/script/{script_id}/execute")
def execute_script(script_id: int): def execute_script(script_id: int, current_user: User = Depends(get_current_user)):
run_scripts([script_id]) run_scripts([script_id])
return {"run_script": True} return {"run_script": True}

View File

@@ -5,7 +5,6 @@ from model import SessionLocal, Subscription, Settings, Notification
import json import json
# Constants # Constants
NTFY_TOKEN = os.getenv("NTFY_TOKEN") NTFY_TOKEN = os.getenv("NTFY_TOKEN")
@@ -34,14 +33,11 @@ def fetch_ntfy_notifications(base_url, subscriptions):
notifications.append(notification) notifications.append(notification)
print(f"Fetched {len(notifications)} notifications") print(f"Fetched {len(notifications)} notifications")
print(notifications)
return notifications return notifications
def save_notifications_to_db(notifications, topic_to_subscription, db): def save_notifications_to_db(notifications, topic_to_subscription, db):
"""Save the fetched notifications to the database and update last_message_id.""" """Save the fetched notifications to the database and update last_message_id."""
db = SessionLocal()
last_message_ids = {} last_message_ids = {}
for notification in notifications: for notification in notifications:
topic = notification["topic"] topic = notification["topic"]
@@ -67,33 +63,26 @@ def save_notifications_to_db(notifications, topic_to_subscription, db):
if subscription: if subscription:
subscription.last_message_id = message_id subscription.last_message_id = message_id
db.commit() db.commit()
db.close()
def main(): def process_user_notifications(user_settings, db):
"""Main function to fetch and save notifications.""" """Process notifications for a specific user's subscriptions."""
db = SessionLocal() ntfy_url = user_settings.ntfy_url
# Get the ntfy base URL from settings
settings = db.query(Settings).filter(Settings.user == "default").first()
if not settings:
print("Default user settings not found.")
return
ntfy_url = settings.ntfy_url
if not ntfy_url: if not ntfy_url:
print("Ntfy URL not found in settings.") print(f"Ntfy URL not found for user ID {user_settings.user_id}. Skipping...")
return return
# Get all subscribed topics # Get all subscriptions for the user
subscriptions = db.query(Subscription).all() subscriptions = (
db.query(Subscription)
.filter(Subscription.user_id == user_settings.user_id)
.all()
)
topic_to_subscription = { topic_to_subscription = {
subscription.topic: subscription.id for subscription in subscriptions subscription.topic: subscription.id for subscription in subscriptions
} }
db.close()
# Fetch notifications from ntfy.sh # Fetch notifications from ntfy.sh
notifications = fetch_ntfy_notifications(ntfy_url, subscriptions) notifications = fetch_ntfy_notifications(ntfy_url, subscriptions)
@@ -101,5 +90,24 @@ def main():
save_notifications_to_db(notifications, topic_to_subscription, db) save_notifications_to_db(notifications, topic_to_subscription, db)
def main():
"""Main function to fetch and save notifications for all users."""
db = SessionLocal()
# Get all user settings
user_settings_list = db.query(Settings).all()
if not user_settings_list:
print("No user settings found.")
return
# Process notifications for each user
for user_settings in user_settings_list:
print(f"Processing notifications for user ID {user_settings.user_id}")
process_user_notifications(user_settings, db)
db.close()
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -1,9 +1,12 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, ForeignKey, Boolean from sqlalchemy import create_engine, Column, Integer, String, Text, ForeignKey, Boolean
from sqlalchemy.sql.sqltypes import DateTime
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql.functions import func from sqlalchemy.sql.functions import func
from sqlalchemy.sql.sqltypes import DateTime from sqlalchemy.sql.sqltypes import DateTime
import os import os
import secrets
from passlib.context import CryptContext
# Initialize the database # Initialize the database
DATABASE_URL = os.getenv("DATABASE_URL") DATABASE_URL = os.getenv("DATABASE_URL")
@@ -17,7 +20,15 @@ SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
Base = declarative_base() Base = declarative_base()
# Define the table model class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(64), unique=True, nullable=False, index=True)
password_hash = Column(String(128), nullable=False)
created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Script(Base): class Script(Base):
@@ -30,6 +41,9 @@ class Script(Base):
created_at = Column( created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
user_id = Column(
Integer, ForeignKey("users.id", name="fk_script_user_id"), nullable=False
)
class Log(Base): class Log(Base):
@@ -43,7 +57,9 @@ class Log(Base):
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
script_id = Column(Integer, ForeignKey("scripts.id"), nullable=False) script_id = Column(
Integer, ForeignKey("scripts.id", name="fk_log_script_id"), nullable=False
)
class Settings(Base): class Settings(Base):
@@ -52,8 +68,10 @@ class Settings(Base):
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
requirements = Column(String, nullable=False) requirements = Column(String, nullable=False)
environment = Column(String, nullable=False) environment = Column(String, nullable=False)
user = Column(String, nullable=False)
ntfy_url = Column(String, nullable=True) ntfy_url = Column(String, nullable=True)
user_id = Column(
Integer, ForeignKey("users.id", name="fk_user_settings_user_id"), nullable=False
)
class Subscription(Base): class Subscription(Base):
@@ -65,6 +83,9 @@ class Subscription(Base):
created_at = Column( created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
user_id = Column(
Integer, ForeignKey("users.id", name="fk_subscription_user_id"), nullable=False
)
class Notification(Base): class Notification(Base):
@@ -77,7 +98,11 @@ class Notification(Base):
viewed = Column(Boolean, default=False) viewed = Column(Boolean, default=False)
sent = Column(Boolean, default=False) sent = Column(Boolean, default=False)
subscription_id = Column(Integer, ForeignKey("subscriptions.id"), nullable=False) subscription_id = Column(
Integer,
ForeignKey("subscriptions.id", name="fk_notification_subscription_id"),
nullable=False,
)
created_at = Column( created_at = Column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
@@ -87,20 +112,39 @@ class Notification(Base):
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
# Ensure a default setting line exists # Ensure a default admin user exists
def ensure_default_setting(): def ensure_default_setting():
db = SessionLocal() db = SessionLocal()
default_setting = db.query(Settings).filter(Settings.user == "default").first() admin_user = db.query(User).filter(User.username == "admin").first()
if not admin_user:
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
random_password = secrets.token_urlsafe(12)
password_hash = pwd_context.hash(random_password)
admin_user = User(username="admin", password_hash=password_hash)
db.add(admin_user)
db.commit()
print(
f"Default admin user created. Username: admin, Password: {random_password}"
)
# Refresh to get admin_user.id
db.refresh(admin_user)
# Set all rows with null user_id in Script and Subscription to admin user id
db.query(Script).filter(Script.user_id is None).update({"user_id": admin_user.id})
db.query(Subscription).filter(Subscription.user_id is None).update(
{"user_id": admin_user.id}
)
db.commit()
default_setting = (
db.query(Settings).filter(Settings.user_id == admin_user.id).first()
)
if not default_setting: if not default_setting:
new_setting = Settings( new_setting = Settings(
requirements="", requirements="",
environment="", environment="",
user="default", user_id=admin_user.id,
ntfy_url="https://ntfy.abzk.fr", ntfy_url="https://ntfy.abzk.fr",
) )
db.add(new_setting) db.add(new_setting)
db.commit() db.commit()
db.close() db.close()
ensure_default_setting()