implement oauth2
This commit is contained in:
parent
04833ff7e0
commit
02bf209215
2 changed files with 88 additions and 13 deletions
|
@ -1,6 +1,7 @@
|
|||
from flask import render_template, redirect, request, url_for, flash, session, abort, current_app
|
||||
from flask_login import login_required, login_user, logout_user, current_user
|
||||
from markupsafe import Markup, escape
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from . import auth
|
||||
from .forms import LoginForm, TwoFAForm, RegistrationForm, ChangePasswordForm, PasswordResetRequestForm, PasswordResetForm
|
||||
|
@ -12,19 +13,91 @@ from models import db, User
|
|||
from io import BytesIO
|
||||
import pyqrcode
|
||||
|
||||
def get_google_auth(state=None, token=None):
|
||||
if token:
|
||||
return OAuth2Session(current_app.config['CLIENT_ID'], token=token)
|
||||
if state:
|
||||
return OAuth2Session(
|
||||
current_app.config['CLIENT_ID'],
|
||||
state=state,
|
||||
redirect_uri=current_app.config['REDIRECT_URI'])
|
||||
oauth = OAuth2Session(
|
||||
current_app.config['CLIENT_ID'],
|
||||
redirect_uri=current_app.config['REDIRECT_URI'],
|
||||
scope=current_app.config['SCOPE'])
|
||||
return oauth
|
||||
@app.route('/authorize/<provider>')
|
||||
def oauth2_authorize(provider):
|
||||
if not current_user.is_anonymous:
|
||||
return redirect(url_for('index'))
|
||||
|
||||
provider_data = current_app.config['OAUTH2_PROVIDERS'].get(provider)
|
||||
if provider_data is None:
|
||||
abort(404)
|
||||
|
||||
# generate a random string for the state parameter
|
||||
session['oauth2_state'] = secrets.token_urlsafe(16)
|
||||
|
||||
# create a query string with all the OAuth2 parameters
|
||||
qs = urlencode({
|
||||
'client_id': provider_data['client_id'],
|
||||
'redirect_uri': url_for('oauth2_callback', provider=provider,
|
||||
_external=True),
|
||||
'response_type': 'code',
|
||||
'scope': ' '.join(provider_data['scopes']),
|
||||
'state': session['oauth2_state'],
|
||||
})
|
||||
|
||||
# redirect the user to the OAuth2 provider authorization URL
|
||||
return redirect(provider_data['authorize_url'] + '?' + qs)
|
||||
|
||||
@app.route('/callback/<provider>')
|
||||
def oauth2_callback(provider):
|
||||
if not current_user.is_anonymous:
|
||||
return redirect(url_for('index'))
|
||||
|
||||
provider_data = current_app.config['OAUTH2_PROVIDERS'].get(provider)
|
||||
if provider_data is None:
|
||||
abort(404)
|
||||
|
||||
# if there was an authentication error, flash the error messages and exit
|
||||
if 'error' in request.args:
|
||||
for k, v in request.args.items():
|
||||
if k.startswith('error'):
|
||||
flash(f'{k}: {v}')
|
||||
return redirect(url_for('index'))
|
||||
|
||||
# make sure that the state parameter matches the one we created in the
|
||||
# authorization request
|
||||
if request.args['state'] != session.get('oauth2_state'):
|
||||
abort(401)
|
||||
|
||||
# make sure that the authorization code is present
|
||||
if 'code' not in request.args:
|
||||
abort(401)
|
||||
|
||||
# exchange the authorization code for an access token
|
||||
response = requests.post(provider_data['token_url'], data={
|
||||
'client_id': provider_data['client_id'],
|
||||
'client_secret': provider_data['client_secret'],
|
||||
'code': request.args['code'],
|
||||
'grant_type': 'authorization_code',
|
||||
'redirect_uri': url_for('oauth2_callback', provider=provider,
|
||||
_external=True),
|
||||
}, headers={'Accept': 'application/json'})
|
||||
if response.status_code != 200:
|
||||
abort(401)
|
||||
oauth2_token = response.json().get('access_token')
|
||||
if not oauth2_token:
|
||||
abort(401)
|
||||
|
||||
# use the access token to get the user's email address
|
||||
response = requests.get(provider_data['userinfo']['url'], headers={
|
||||
'Authorization': 'Bearer ' + oauth2_token,
|
||||
'Accept': 'application/json',
|
||||
})
|
||||
if response.status_code != 200:
|
||||
abort(401)
|
||||
email = provider_data['userinfo']['email'](response.json())
|
||||
|
||||
# find or create the user in the database
|
||||
user = db.session.scalar(db.select(User).where(User.email == email))
|
||||
if user is None:
|
||||
#user = User(email=email, username=email.split('@')[0])
|
||||
user = User(email=email)
|
||||
db.session.add(user)
|
||||
db.session.commit()
|
||||
|
||||
# log the user in
|
||||
login_user(user)
|
||||
return redirect(url_for('index'))
|
||||
|
||||
@auth.before_app_request
|
||||
def before_request():
|
||||
|
|
|
@ -16,6 +16,8 @@ PGADMIN_DEFAULT_PASSWORD=hackme
|
|||
|
||||
GOOGLE_CLIENT_ID=changeme
|
||||
GOOGLE_CLIENT_SECRET=changeme
|
||||
GITHUB_CLIENT_ID=changeme
|
||||
GITHUB_CLIENT_SECRET=changeme
|
||||
|
||||
MAIL_SENDER=mail@example.com
|
||||
MAIL_SUBJECT_PREFIX=ForestNet
|
||||
|
|
Loading…
Add table
Reference in a new issue