"auth" -> "user"

This commit is contained in:
Jörn-Michael Miehe 2022-03-18 23:04:28 +00:00
parent 98c89991b4
commit b42a5b44f3
5 changed files with 63 additions and 28 deletions

View file

@ -11,6 +11,5 @@
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.organizeImports": true
},
"python.formatting.provider": "black"
}
}

View file

@ -5,6 +5,32 @@ from sqlalchemy.orm import Session, sessionmaker
from .models import ORMBaseModel
class Connection:
engine: Engine | None = None
session_local: sessionmaker | None = None
@classmethod
def connect(cls, engine: Engine) -> None:
cls.engine = engine
cls.session_local = sessionmaker(
autocommit=False, autoflush=False, bind=engine,
)
ORMBaseModel.metadata.create_all(bind=engine)
@classmethod
async def get(cls) -> Generator[Session | None, None, None]:
if cls.session_local is None:
yield None
else:
db = cls.session_local()
try:
yield db
finally:
db.close()
ENGINE: Engine | None = None
SESSION_LOCAL: sessionmaker | None = None

View file

@ -4,8 +4,8 @@ import uvicorn
from fastapi import FastAPI
from .config import Config, Settings
from .db import connection, crud
from .routers import admin
from .db.connection import Connection
from .routers import admin, user
settings = Settings.get()
@ -36,16 +36,16 @@ async def on_startup():
api.include_router(admin.router)
if (current_config := await Config.get()) is not None:
connection.reconnect(current_config.db_engine)
Connection.connect(current_config.db_engine)
async for db in connection.get():
user = crud.get_user(db, "admin")
print(user.name)
for cap in user.capabilities:
print(cap.capability)
# async for db in connection.get():
# user = crud.get_user(db, "admin")
# print(user.name)
# for cap in user.capabilities:
# print(cap.capability)
# include other routers
# api.include_router(auth.router)
api.include_router(user.router)
def main():

View file

@ -4,7 +4,8 @@ from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from ..config import Config
from ..db import connection, crud, schemas
from ..db import crud, schemas
from ..db.connection import Connection
router = APIRouter(prefix="/admin")
@ -31,7 +32,7 @@ async def set_config(
if new_config.jwt.secret is None:
new_config.jwt.secret = token_hex(32)
connection.reconnect(new_config.db_engine)
Connection.connect(new_config.db_engine)
Config.set(new_config)
@ -52,7 +53,7 @@ async def add_user(
user_name: str,
user_password: str,
current_config: Config | None = Depends(Config.get),
db: Session | None = Depends(connection.get),
db: Session | None = Depends(Connection.get),
):
if current_config is None:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)

View file

@ -4,15 +4,14 @@ from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import jwt
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ..config import (ACCESS_TOKEN_EXPIRE_MINUTES, ALGORITHM, CRYPT_CONTEXT,
SECRET_KEY)
from ..db import User
from ..config import Config
from ..db import crud
from ..db.connection import Connection
router = APIRouter(prefix="/auth")
SCHEME = OAuth2PasswordBearer(
tokenUrl=f".{router.prefix}/token"
)
router = APIRouter(prefix="/user")
SCHEME = OAuth2PasswordBearer(tokenUrl=f".{router.prefix}/token")
class Token(BaseModel):
@ -20,18 +19,28 @@ class Token(BaseModel):
token_type: str
def create_access_token(data: dict, expires_delta: timedelta | None = None):
def create_access_token(
data: dict,
expires_delta: timedelta | None = None,
config: Config = Depends(Config.get),
):
to_encode = data.copy()
if expires_delta is None:
expires_delta = timedelta(minutes=15)
to_encode.update({"exp": datetime.utcnow() + expires_delta})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
return jwt.encode(
to_encode,
config.jwt.secret,
algorithm=config.jwt.hash_algorithm,
)
@router.post("/token", response_model=Token)
@router.post("/auth", response_model=Token)
async def login_for_access_token(
form_data: OAuth2PasswordRequestForm = Depends()
form_data: OAuth2PasswordRequestForm = Depends(),
config: Config = Depends(Config.get),
db: Session = Depends(Connection.get),
):
credentials_exception = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -39,9 +48,9 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"},
)
user = User.get_by_name(form_data.username)
user = crud.get_user(db, form_data.username)
if user is None:
CRYPT_CONTEXT.dummy_verify()
config.crypt_context.dummy_verify()
raise credentials_exception
if not user.verify(form_data.password):
@ -49,6 +58,6 @@ async def login_for_access_token(
access_token = create_access_token(
data={"sub": user.name},
expires_delta=timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
expires_delta=timedelta(minutes=config.jwt.expiry_minutes),
)
return {"access_token": access_token, "token_type": "bearer"}