diff --git a/api/.vscode/settings.json b/api/.vscode/settings.json index 78086c1..cd32962 100644 --- a/api/.vscode/settings.json +++ b/api/.vscode/settings.json @@ -11,6 +11,5 @@ "editor.formatOnSave": true, "editor.codeActionsOnSave": { "source.organizeImports": true - }, - "python.formatting.provider": "black" + } } \ No newline at end of file diff --git a/api/kiwi_vpn_api/db/connection.py b/api/kiwi_vpn_api/db/connection.py index 6c45fd8..23c537d 100644 --- a/api/kiwi_vpn_api/db/connection.py +++ b/api/kiwi_vpn_api/db/connection.py @@ -5,6 +5,32 @@ from sqlalchemy.orm import Session, sessionmaker from .models import ORMBaseModel + +class Connection: + engine: Engine | None = None + session_local: sessionmaker | None = None + + @classmethod + def connect(cls, engine: Engine) -> None: + cls.engine = engine + cls.session_local = sessionmaker( + autocommit=False, autoflush=False, bind=engine, + ) + ORMBaseModel.metadata.create_all(bind=engine) + + @classmethod + async def get(cls) -> Generator[Session | None, None, None]: + if cls.session_local is None: + yield None + + else: + db = cls.session_local() + try: + yield db + finally: + db.close() + + ENGINE: Engine | None = None SESSION_LOCAL: sessionmaker | None = None diff --git a/api/kiwi_vpn_api/main.py b/api/kiwi_vpn_api/main.py index b1c87e5..2612862 100755 --- a/api/kiwi_vpn_api/main.py +++ b/api/kiwi_vpn_api/main.py @@ -4,8 +4,8 @@ import uvicorn from fastapi import FastAPI from .config import Config, Settings -from .db import connection, crud -from .routers import admin +from .db.connection import Connection +from .routers import admin, user settings = Settings.get() @@ -36,16 +36,16 @@ async def on_startup(): api.include_router(admin.router) if (current_config := await Config.get()) is not None: - connection.reconnect(current_config.db_engine) + Connection.connect(current_config.db_engine) - async for db in connection.get(): - user = crud.get_user(db, "admin") - print(user.name) - for cap in user.capabilities: - print(cap.capability) + # async for db in connection.get(): + # user = crud.get_user(db, "admin") + # print(user.name) + # for cap in user.capabilities: + # print(cap.capability) # include other routers - # api.include_router(auth.router) + api.include_router(user.router) def main(): diff --git a/api/kiwi_vpn_api/routers/admin.py b/api/kiwi_vpn_api/routers/admin.py index 8926418..76f999c 100644 --- a/api/kiwi_vpn_api/routers/admin.py +++ b/api/kiwi_vpn_api/routers/admin.py @@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.orm import Session from ..config import Config -from ..db import connection, crud, schemas +from ..db import crud, schemas +from ..db.connection import Connection router = APIRouter(prefix="/admin") @@ -31,7 +32,7 @@ async def set_config( if new_config.jwt.secret is None: new_config.jwt.secret = token_hex(32) - connection.reconnect(new_config.db_engine) + Connection.connect(new_config.db_engine) Config.set(new_config) @@ -52,7 +53,7 @@ async def add_user( user_name: str, user_password: str, current_config: Config | None = Depends(Config.get), - db: Session | None = Depends(connection.get), + db: Session | None = Depends(Connection.get), ): if current_config is None: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) diff --git a/api/kiwi_vpn_api/routers/auth.py b/api/kiwi_vpn_api/routers/user.py similarity index 54% rename from api/kiwi_vpn_api/routers/auth.py rename to api/kiwi_vpn_api/routers/user.py index 52896c8..797ce96 100644 --- a/api/kiwi_vpn_api/routers/auth.py +++ b/api/kiwi_vpn_api/routers/user.py @@ -4,15 +4,14 @@ from fastapi import APIRouter, Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from jose import jwt from pydantic import BaseModel +from sqlalchemy.orm import Session -from ..config import (ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, CRYPT_CONTEXT, - SECRET_KEY) -from ..db import User +from ..config import Config +from ..db import crud +from ..db.connection import Connection -router = APIRouter(prefix="/auth") -SCHEME = OAuth2PasswordBearer( - tokenUrl=f".{router.prefix}/token" -) +router = APIRouter(prefix="/user") +SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token") class Token(BaseModel): @@ -20,18 +19,28 @@ class Token(BaseModel): token_type: str -def create_access_token(data: dict, expires_delta: timedelta | None = None): +def create_access_token( + data: dict, + expires_delta: timedelta | None = None, + config: Config = Depends(Config.get), +): to_encode = data.copy() if expires_delta is None: expires_delta = timedelta(minutes=15) to_encode.update({"exp": datetime.utcnow() + expires_delta}) - return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) + return jwt.encode( + to_encode, + config.jwt.secret, + algorithm=config.jwt.hash_algorithm, + ) -@router.post("/token", response_model=Token) +@router.post("/auth", response_model=Token) async def login_for_access_token( - form_data: OAuth2PasswordRequestForm = Depends() + form_data: OAuth2PasswordRequestForm = Depends(), + config: Config = Depends(Config.get), + db: Session = Depends(Connection.get), ): credentials_exception = HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -39,9 +48,9 @@ async def login_for_access_token( headers={"WWW-Authenticate": "Bearer"}, ) - user = User.get_by_name(form_data.username) + user = crud.get_user(db, form_data.username) if user is None: - CRYPT_CONTEXT.dummy_verify() + config.crypt_context.dummy_verify() raise credentials_exception if not user.verify(form_data.password): @@ -49,6 +58,6 @@ async def login_for_access_token( access_token = create_access_token( data={"sub": user.name}, - expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), + expires_delta=timedelta(minutes=config.jwt.expiry_minutes), ) return {"access_token": access_token, "token_type": "bearer"}