commented code and fixed security issue regarding the pgp key, now the database only stores the fingerprint and when loggin in it checks if the submited pgp key has the same fingerprint as the one in the database

This commit is contained in:
bacalhau 2026-03-04 11:14:27 +00:00
parent a882030ed5
commit ef4d8c7486

View file

@ -6,17 +6,19 @@ import secrets
import os import os
from werkzeug.utils import secure_filename from werkzeug.utils import secure_filename
# defines where the upload folder is and creates it
UPLOAD_FOLDER = "static/uploads" UPLOAD_FOLDER = "static/uploads"
os.makedirs(UPLOAD_FOLDER, exist_ok=True) os.makedirs(UPLOAD_FOLDER, exist_ok=True)
app = Flask(__name__) # configures the app
app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://love:love@localhost:3309/lovedb' app = Flask(__name__) # creates de app
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False app.config['SQLALCHEMY_DATABASE_URI'] = 'mysql+pymysql://love:love@localhost:3309/lovedb' # database connection
app.config['SECRET_KEY'] = 'random' app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False # disable track modifications (for better performance)
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER app.config['SECRET_KEY'] = 'random' # sets the secret key used to generate random numbers
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER # sets the upload folder
db = SQLAlchemy(app) db = SQLAlchemy(app) # its like a shortcut to the database
gpg = gnupg.GPG() gpg = gnupg.GPG() # same as above but for gpg
COUNTRIES = [ "Afghanistan","Albania","Algeria","Andorra","Angola","Antigua and Barbuda","Argentina", COUNTRIES = [ "Afghanistan","Albania","Algeria","Andorra","Angola","Antigua and Barbuda","Argentina",
"Armenia","Australia","Austria","Azerbaijan","Bahamas","Bahrain","Bangladesh", "Armenia","Australia","Austria","Azerbaijan","Bahamas","Bahrain","Bangladesh",
@ -45,6 +47,7 @@ COUNTRIES = [ "Afghanistan","Albania","Algeria","Andorra","Angola","Antigua and
"Uzbekistan","Vanuatu","Venezuela","Vietnam","Yemen","Zambia","Zimbabwe" ] "Uzbekistan","Vanuatu","Venezuela","Vietnam","Yemen","Zambia","Zimbabwe" ]
# Database creation
class User(db.Model): class User(db.Model):
id = db.Column(db.Integer, primary_key=True) id = db.Column(db.Integer, primary_key=True)
username = db.Column(db.String(128), unique=True, nullable=False) username = db.Column(db.String(128), unique=True, nullable=False)
@ -68,20 +71,26 @@ class User(db.Model):
phone = db.Column(db.String(20), unique=True, nullable=True) phone = db.Column(db.String(20), unique=True, nullable=True)
is_verified = db.Column(db.Boolean, default=False) is_verified = db.Column(db.Boolean, default=False)
# calculates user age
def calculate_age(dob: date) -> int: def calculate_age(dob: date) -> int:
today = date.today() today = date.today()
return today.year - dob.year - ((today.month, today.day) < (dob.month, dob.day)) return today.year - dob.year - ((today.month, today.day) < (dob.month, dob.day))
# saves files to the upload folder and returns their URL
def save_files(username: str, profile_file, pictures_files): def save_files(username: str, profile_file, pictures_files):
# creates a path for the user inside the upload forlder
user_folder = os.path.join(app.config['UPLOAD_FOLDER'], username) user_folder = os.path.join(app.config['UPLOAD_FOLDER'], username)
os.makedirs(user_folder, exist_ok=True) os.makedirs(user_folder, exist_ok=True)
# prevents unsafe characters to be used in the filename
profile_filename = secure_filename(profile_file.filename) profile_filename = secure_filename(profile_file.filename)
profile_path = os.path.join(user_folder, profile_filename) profile_path = os.path.join(user_folder, profile_filename)
# saves the profile picture to the path
profile_file.save(profile_path) profile_file.save(profile_path)
profile_url = f"/{profile_path.replace(os.sep, '/')}" profile_url = f"/{profile_path.replace(os.sep, '/')}"
# saves all of the other pictures
pictures_urls = [] pictures_urls = []
for pic in pictures_files: for pic in pictures_files:
if pic.filename: if pic.filename:
@ -92,16 +101,21 @@ def save_files(username: str, profile_file, pictures_files):
return profile_url, pictures_urls return profile_url, pictures_urls
# encrypts the chalange for the user to then decrypt with pgp
def pgp_encrypt_and_import(pgp_key: str, message: str): def pgp_encrypt_and_import(pgp_key: str, message: str):
# imports the user's key
result = gpg.import_keys(pgp_key) result = gpg.import_keys(pgp_key)
# check to see if the key has fingerprints
if not result.fingerprints: if not result.fingerprints:
return None, None return None, None
fingerprint = result.fingerprints[0] fingerprint = result.fingerprints[0]
# encrypts message to the user's fingerprint
encrypted = gpg.encrypt(message, recipients=[fingerprint]) encrypted = gpg.encrypt(message, recipients=[fingerprint])
if not encrypted.ok: if not encrypted.ok:
return fingerprint, None return fingerprint, None
return fingerprint, str(encrypted) return fingerprint, str(encrypted)
# ROUTES ------------------------------------------------------------------------------------------------------
@app.route("/") @app.route("/")
def home(): def home():
@ -111,51 +125,65 @@ def home():
@app.route("/register", methods=["GET", "POST"]) @app.route("/register", methods=["GET", "POST"])
def register(): def register():
if request.method == "POST": if request.method == "POST":
# collect data to a dictionary
data = {key: request.form.get(key) for key in [ data = {key: request.form.get(key) for key in [
"username","pgp","firstname","lastname","sex","date_of_birth","country","xmpp", "username","pgp","firstname","lastname","sex","date_of_birth","country","xmpp",
"email","phone","city","height","weight","race","prefered_age_range" "email","phone","city","height","weight","race","prefered_age_range"
]} ]}
# required fields
required_fields = ["username","pgp","firstname","lastname","sex","date_of_birth","country","xmpp"] required_fields = ["username","pgp","firstname","lastname","sex","date_of_birth","country","xmpp"]
if not all(data[f] for f in required_fields): if not all(data[f] for f in required_fields):
flash("Please fill all required fields.") flash("Please fill all required fields.")
return redirect(url_for("register")) return redirect(url_for("register"))
# check if fields are unique
for field in ["username","xmpp","email","phone"]: for field in ["username","xmpp","email","phone"]:
if data.get(field) and User.query.filter_by(**{field:data[field]}).first(): if data.get(field) and User.query.filter_by(**{field:data[field]}).first():
flash(f"{field.capitalize()} already exists.") flash(f"{field.capitalize()} already exists.")
return redirect(url_for("register")) return redirect(url_for("register"))
# validates date format to iso (YYYY-MM-DD)
try: try:
dob = date.fromisoformat(data["date_of_birth"]) dob = date.fromisoformat(data["date_of_birth"])
except ValueError: except ValueError:
flash("Invalid date format.") flash("Invalid date format.")
return redirect(url_for("register")) return redirect(url_for("register"))
# blocks underage users
if calculate_age(dob) < 18: if calculate_age(dob) < 18:
flash("You must be at least 18 years old to register.") flash("You must be at least 18 years old to register.")
return redirect(url_for("register")) return redirect(url_for("register"))
# retrieves the user uploaded pictures
profile_file = request.files.get("profile_picture") profile_file = request.files.get("profile_picture")
pictures_files = request.files.getlist("pictures") pictures_files = request.files.getlist("pictures")
# doesn't let the user create an account without a profile picture
if not profile_file: if not profile_file:
flash("Profile picture is required.") flash("Profile picture is required.")
return redirect(url_for("register")) return redirect(url_for("register"))
# saves the users pictures
profile_url, pictures_urls = save_files(data["username"], profile_file, pictures_files) profile_url, pictures_urls = save_files(data["username"], profile_file, pictures_files)
# creates a random string
random_string = secrets.token_hex(16) random_string = secrets.token_hex(16)
# uses the string to create the message that wll be encrypted
challenge_phrase = f"this is the unencrypted string: {random_string}" challenge_phrase = f"this is the unencrypted string: {random_string}"
# encrypts message
fingerprint, encrypted_msg = pgp_encrypt_and_import(data["pgp"], challenge_phrase) fingerprint, encrypted_msg = pgp_encrypt_and_import(data["pgp"], challenge_phrase)
# checks fingerprint
if not fingerprint or not encrypted_msg: if not fingerprint or not encrypted_msg:
flash("Invalid PGP key or encryption failed.") flash("Invalid PGP key or encryption failed.")
return redirect(url_for("register")) return redirect(url_for("register"))
# creates a temporary session used to verify the user
session["pending_user"] = {**data, "profile_url": profile_url, "pictures_urls": pictures_urls} session["pending_user"] = {**data, "profile_url": profile_url, "pictures_urls": pictures_urls}
session["pgp_expected_phrase"] = challenge_phrase session["pgp_expected_phrase"] = challenge_phrase
# renders the verification page
return render_template("verify.html", encrypted_message=encrypted_msg) return render_template("verify.html", encrypted_message=encrypted_msg)
return render_template("register.html", countries=COUNTRIES) return render_template("register.html", countries=COUNTRIES)
@ -163,26 +191,35 @@ def register():
@app.route("/verify", methods=["POST"]) @app.route("/verify", methods=["POST"])
def verify(): def verify():
# retrieve the phrase from the session
expected_phrase = session.get("pgp_expected_phrase") expected_phrase = session.get("pgp_expected_phrase")
# retrieve user data from the session
data = session.get("pending_user") data = session.get("pending_user")
# check to see if data exists
if not data or not expected_phrase: if not data or not expected_phrase:
flash("Session expired.") flash("Session expired.")
return redirect(url_for("register")) return redirect(url_for("register"))
# get the decrypted message
submitted = request.form.get("decrypted_message") submitted = request.form.get("decrypted_message")
# check to see if submission was empty
if not submitted: if not submitted:
flash("You must paste the decrypted message.") flash("You must paste the decrypted message.")
return redirect(url_for("register")) return redirect(url_for("register"))
# checks if frase is correct
if submitted.strip() != expected_phrase: if submitted.strip() != expected_phrase:
flash("Verification failed. Account not created.") flash("Verification failed. Account not created.")
return redirect(url_for("register")) return redirect(url_for("register"))
# saves the correcty formated date of birth
dob = date.fromisoformat(data["date_of_birth"]) dob = date.fromisoformat(data["date_of_birth"])
# stores the data on the database
new_user = User( new_user = User(
username=data["username"], username=data["username"],
pgp=data["pgp"], pgp=fingerprint, # i store the fingerprint not the whole pgp key
firstname=data["firstname"], firstname=data["firstname"],
lastname=data["lastname"], lastname=data["lastname"],
sex=data["sex"], sex=data["sex"],
@ -204,8 +241,10 @@ def verify():
db.session.add(new_user) db.session.add(new_user)
db.session.commit() db.session.commit()
# creates login session
session['user_id'] = new_user.id session['user_id'] = new_user.id
session['username'] = new_user.username session['username'] = new_user.username
# remove temporary session
session.pop("pending_user", None) session.pop("pending_user", None)
session.pop("pgp_expected_phrase", None) session.pop("pgp_expected_phrase", None)
@ -215,6 +254,7 @@ def verify():
@app.route("/login", methods=["GET","POST"]) @app.route("/login", methods=["GET","POST"])
def login(): def login():
# Requests username and pgp
if request.method == "POST": if request.method == "POST":
username = request.form.get("username") username = request.form.get("username")
pgp_key = request.form.get("pgp") pgp_key = request.form.get("pgp")
@ -223,48 +263,74 @@ def login():
flash("Please enter both username and PGP key.") flash("Please enter both username and PGP key.")
return redirect(url_for("login")) return redirect(url_for("login"))
# cehcks if user exists
user = User.query.filter_by(username=username).first() user = User.query.filter_by(username=username).first()
if not user: if not user:
flash("User not found.") flash("User not found.")
return redirect(url_for("login")) return redirect(url_for("login"))
random_string = secrets.token_hex(16) # checks if imported pgp key has valid fingerprints
challenge_phrase = f"this is the unencrypted string: {random_string}" pgp = gpg.import_keys(pgp_key)
fingerprint, encrypted_msg = pgp_encrypt_and_import(pgp_key, challenge_phrase) if not pgp.fingerprints:
flash("Invalid PGP key.")
if not fingerprint or not encrypted_msg:
flash("Invalid PGP key or encryption failed.")
return redirect(url_for("login")) return redirect(url_for("login"))
# retrieves fingerprint
submitted_fingerprint = pgp.fingerprints[0]
# Checks if pgp matches the user's pgp
if submitted_fingerprint != user.pgp:
flash("PGP key does not match our records.")
return redirect(url_for("login"))
# Generate a challenge for PGP verification
random_string = secrets.token_hex(16)
challenge_phrase = f"this is the unencrypted string: {random_string}"
# Encrypt the challenge phrase using the stored fingerprint
encrypted = gpg.encrypt(challenge_phrase, recipients=[submitted_fingerprint])
if not encrypted.ok:
flash("Encryption failed.")
return redirect(url_for("login"))
# Store login verification data in session (temporary)
session["login_user_id"] = user.id session["login_user_id"] = user.id
session["login_expected_phrase"] = challenge_phrase session["login_expected_phrase"] = challenge_phrase
return render_template("login_verify.html", encrypted_message=encrypted_msg) # Render page where user will paste decrypted message
return render_template("login_verify.html", encrypted_message=str(encrypted))
return render_template("login.html") return render_template("login.html")
@app.route("/login_verify", methods=["POST"]) @app.route("/login_verify", methods=["POST"])
def login_verify(): def login_verify():
# get the temporary session data
user_id = session.get("login_user_id") user_id = session.get("login_user_id")
expected_phrase = session.get("login_expected_phrase") expected_phrase = session.get("login_expected_phrase")
# cehcks if session exists
if not user_id or not expected_phrase: if not user_id or not expected_phrase:
flash("Login session expired") flash("Login session expired")
return redirect(url_for("login")) return redirect(url_for("login"))
# cehcks if decrypted frase was submited
submitted = request.form.get("decrypted_message") submitted = request.form.get("decrypted_message")
if not submitted: if not submitted:
flash("You must paste the decrypted message") flash("You must paste the decrypted message")
return redirect(url_for("login")) return redirect(url_for("login"))
# Checks if submited frase matches the expected
if submitted.strip() != expected_phrase: if submitted.strip() != expected_phrase:
flash("Verification failed") flash("Verification failed")
return redirect(url_for("login")) return redirect(url_for("login"))
# saves session
user = User.query.get(user_id) user = User.query.get(user_id)
session['user_id'] = user.id session['user_id'] = user.id
session['username'] = user.username session['username'] = user.username
# removes temporary session
session.pop("login_user_id", None) session.pop("login_user_id", None)
session.pop("login_expected_phrase", None) session.pop("login_expected_phrase", None)
@ -274,12 +340,13 @@ def login_verify():
@app.route("/logout") @app.route("/logout")
def logout(): def logout():
# removes session
session.pop('user_id', None) session.pop('user_id', None)
session.pop('username', None) session.pop('username', None)
flash("Logged out successfully") flash("Logged out successfully")
return redirect(url_for("home")) return redirect(url_for("home"))
# Renders users route
@app.route("/user/<username>") @app.route("/user/<username>")
def user_profile(username): def user_profile(username):
user = User.query.filter_by(username=username).first_or_404() user = User.query.filter_by(username=username).first_or_404()