From 02bf20921514fcb41b7f827bc3c2c492ba72bd64 Mon Sep 17 00:00:00 2001 From: deflax Date: Sat, 6 Apr 2024 17:04:05 +0300 Subject: [PATCH] implement oauth2 --- src/forest/auth/routes.py | 99 ++++++++++++++++++++++++++++++++++----- variables.env.dist | 2 + 2 files changed, 88 insertions(+), 13 deletions(-) diff --git a/src/forest/auth/routes.py b/src/forest/auth/routes.py index cefc5ca..6dfe949 100644 --- a/src/forest/auth/routes.py +++ b/src/forest/auth/routes.py @@ -1,6 +1,7 @@ from flask import render_template, redirect, request, url_for, flash, session, abort, current_app from flask_login import login_required, login_user, logout_user, current_user from markupsafe import Markup, escape +from urllib.parse import urlencode from . import auth from .forms import LoginForm, TwoFAForm, RegistrationForm, ChangePasswordForm, PasswordResetRequestForm, PasswordResetForm @@ -12,19 +13,91 @@ from models import db, User from io import BytesIO import pyqrcode -def get_google_auth(state=None, token=None): - if token: - return OAuth2Session(current_app.config['CLIENT_ID'], token=token) - if state: - return OAuth2Session( - current_app.config['CLIENT_ID'], - state=state, - redirect_uri=current_app.config['REDIRECT_URI']) - oauth = OAuth2Session( - current_app.config['CLIENT_ID'], - redirect_uri=current_app.config['REDIRECT_URI'], - scope=current_app.config['SCOPE']) - return oauth +@app.route('/authorize/') +def oauth2_authorize(provider): + if not current_user.is_anonymous: + return redirect(url_for('index')) + + provider_data = current_app.config['OAUTH2_PROVIDERS'].get(provider) + if provider_data is None: + abort(404) + + # generate a random string for the state parameter + session['oauth2_state'] = secrets.token_urlsafe(16) + + # create a query string with all the OAuth2 parameters + qs = urlencode({ + 'client_id': provider_data['client_id'], + 'redirect_uri': url_for('oauth2_callback', provider=provider, + _external=True), + 'response_type': 'code', + 'scope': ' '.join(provider_data['scopes']), + 'state': session['oauth2_state'], + }) + + # redirect the user to the OAuth2 provider authorization URL + return redirect(provider_data['authorize_url'] + '?' + qs) + +@app.route('/callback/') +def oauth2_callback(provider): + if not current_user.is_anonymous: + return redirect(url_for('index')) + + provider_data = current_app.config['OAUTH2_PROVIDERS'].get(provider) + if provider_data is None: + abort(404) + + # if there was an authentication error, flash the error messages and exit + if 'error' in request.args: + for k, v in request.args.items(): + if k.startswith('error'): + flash(f'{k}: {v}') + return redirect(url_for('index')) + + # make sure that the state parameter matches the one we created in the + # authorization request + if request.args['state'] != session.get('oauth2_state'): + abort(401) + + # make sure that the authorization code is present + if 'code' not in request.args: + abort(401) + + # exchange the authorization code for an access token + response = requests.post(provider_data['token_url'], data={ + 'client_id': provider_data['client_id'], + 'client_secret': provider_data['client_secret'], + 'code': request.args['code'], + 'grant_type': 'authorization_code', + 'redirect_uri': url_for('oauth2_callback', provider=provider, + _external=True), + }, headers={'Accept': 'application/json'}) + if response.status_code != 200: + abort(401) + oauth2_token = response.json().get('access_token') + if not oauth2_token: + abort(401) + + # use the access token to get the user's email address + response = requests.get(provider_data['userinfo']['url'], headers={ + 'Authorization': 'Bearer ' + oauth2_token, + 'Accept': 'application/json', + }) + if response.status_code != 200: + abort(401) + email = provider_data['userinfo']['email'](response.json()) + + # find or create the user in the database + user = db.session.scalar(db.select(User).where(User.email == email)) + if user is None: + #user = User(email=email, username=email.split('@')[0]) + user = User(email=email) + db.session.add(user) + db.session.commit() + + # log the user in + login_user(user) + return redirect(url_for('index')) @auth.before_app_request def before_request(): diff --git a/variables.env.dist b/variables.env.dist index 3d1a1e5..fb5f73b 100644 --- a/variables.env.dist +++ b/variables.env.dist @@ -16,6 +16,8 @@ PGADMIN_DEFAULT_PASSWORD=hackme GOOGLE_CLIENT_ID=changeme GOOGLE_CLIENT_SECRET=changeme +GITHUB_CLIENT_ID=changeme +GITHUB_CLIENT_SECRET=changeme MAIL_SENDER=mail@example.com MAIL_SUBJECT_PREFIX=ForestNet