"auth" -> "user"

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 23:04:28 +00:00
parent 98c89991b4
commit b42a5b44f3
5 changed files with 63 additions and 28 deletions

View file

@ -11,6 +11,5 @@
"editor.formatOnSave": true, "editor.formatOnSave": true,
"editor.codeActionsOnSave": { "editor.codeActionsOnSave": {
"source.organizeImports": true "source.organizeImports": true
}, }
"python.formatting.provider": "black"
} }

View file

@ -5,6 +5,32 @@ from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel from .models import ORMBaseModel
class Connection:
engine: Engine | None = None
session_local: sessionmaker | None = None
@classmethod
def connect(cls, engine: Engine) -> None:
cls.engine = engine
cls.session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
if cls.session_local is None:
yield None
else:
db = cls.session_local()
try:
yield db
finally:
db.close()
ENGINE: Engine | None = None ENGINE: Engine | None = None
SESSION_LOCAL: sessionmaker | None = None SESSION_LOCAL: sessionmaker | None = None

View file

@ -4,8 +4,8 @@ import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from .config import Config, Settings from .config import Config, Settings
from .db import connection, crud from .db.connection import Connection
from .routers import admin from .routers import admin, user
settings = Settings.get() settings = Settings.get()
@ -36,16 +36,16 @@ async def on_startup():
api.include_router(admin.router) api.include_router(admin.router)
if (current_config := await Config.get()) is not None: if (current_config := await Config.get()) is not None:
connection.reconnect(current_config.db_engine) Connection.connect(current_config.db_engine)
async for db in connection.get(): # async for db in connection.get():
user = crud.get_user(db, "admin") # user = crud.get_user(db, "admin")
print(user.name) # print(user.name)
for cap in user.capabilities: # for cap in user.capabilities:
print(cap.capability) # print(cap.capability)
# include other routers # include other routers
# api.include_router(auth.router) api.include_router(user.router)
def main(): def main():

View file

@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from ..config import Config from ..config import Config
from ..db import connection, crud, schemas from ..db import crud, schemas
from ..db.connection import Connection
router = APIRouter(prefix="/admin") router = APIRouter(prefix="/admin")
@ -31,7 +32,7 @@ async def set_config(
if new_config.jwt.secret is None: if new_config.jwt.secret is None:
new_config.jwt.secret = token_hex(32) new_config.jwt.secret = token_hex(32)
connection.reconnect(new_config.db_engine) Connection.connect(new_config.db_engine)
Config.set(new_config) Config.set(new_config)
@ -52,7 +53,7 @@ async def add_user(
user_name: str, user_name: str,
user_password: str, user_password: str,
current_config: Config | None = Depends(Config.get), current_config: Config | None = Depends(Config.get),
db: Session | None = Depends(connection.get), db: Session | None = Depends(Connection.get),
): ):
if current_config is None: if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST) raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

View file

@ -4,15 +4,14 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt from jose import jwt
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import (ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, CRYPT_CONTEXT, from ..config import Config
SECRET_KEY) from ..db import crud
from ..db import User from ..db.connection import Connection
router = APIRouter(prefix="/auth") router = APIRouter(prefix="/user")
SCHEME = OAuth2PasswordBearer( SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token")
tokenUrl=f".{router.prefix}/token"
)
class Token(BaseModel): class Token(BaseModel):
@ -20,18 +19,28 @@ class Token(BaseModel):
token_type: str token_type: str
def create_access_token(data: dict, expires_delta: timedelta | None = None): def create_access_token(
data: dict,
expires_delta: timedelta | None = None,
config: Config = Depends(Config.get),
):
to_encode = data.copy() to_encode = data.copy()
if expires_delta is None: if expires_delta is None:
expires_delta = timedelta(minutes=15) expires_delta = timedelta(minutes=15)
to_encode.update({"exp": datetime.utcnow() + expires_delta}) to_encode.update({"exp": datetime.utcnow() + expires_delta})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) return jwt.encode(
to_encode,
config.jwt.secret,
algorithm=config.jwt.hash_algorithm,
)
@router.post("/token", response_model=Token) @router.post("/auth", response_model=Token)
async def login_for_access_token( async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends() form_data: OAuth2PasswordRequestForm = Depends(),
config: Config = Depends(Config.get),
db: Session = Depends(Connection.get),
): ):
credentials_exception = HTTPException( credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -39,9 +48,9 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
user = User.get_by_name(form_data.username) user = crud.get_user(db, form_data.username)
if user is None: if user is None:
CRYPT_CONTEXT.dummy_verify() config.crypt_context.dummy_verify()
raise credentials_exception raise credentials_exception
if not user.verify(form_data.password): if not user.verify(form_data.password):
@ -49,6 +58,6 @@ async def login_for_access_token(
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.name}, data={"sub": user.name},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), expires_delta=timedelta(minutes=config.jwt.expiry_minutes),
) )
return {"access_token": access_token, "token_type": "bearer"} return {"access_token": access_token, "token_type": "bearer"}