176 lines
5.8 KiB
Python
176 lines
5.8 KiB
Python
#!/usr/bin/env python3
|
|
|
|
# nonstdlib requirements: certbot, cryptography, dasbus, requests
|
|
# non-Python (system) requirements: caddy, firewalld, maddy, setfacl, systemd
|
|
|
|
import datetime
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any, Literal
|
|
|
|
import certbot.main as cbmain
|
|
import dasbus
|
|
import requests
|
|
from cryptography import x509
|
|
from cryptography.hazmat.primitives import hashes, serialization
|
|
from dasbus.connection import SystemMessageBus
|
|
|
|
API_TOKEN = os.environ['API_TOKEN']
|
|
DOMAIN = 'matous.dev'
|
|
MX_SERVER = 'mx1'
|
|
MX_SUBDOMAIN = f'{MX_SERVER}.{DOMAIN}'
|
|
RECORD_NAME = f'_25._tcp.{MX_SERVER}'
|
|
LIVE_CERT_PATH = Path(f'/etc/letsencrypt/live/{MX_SUBDOMAIN}/')
|
|
NJALLA = 'https://njal.la/api/1/'
|
|
|
|
type Method = Literal['add-record', 'list-records', 'remove-record']
|
|
|
|
|
|
def njalla_api(method: Method, **params: int | str) -> dict[str, Any]:
    """Call one Njalla API method for DOMAIN and return its ``result`` payload.

    Args:
        method: Name of the API method to invoke.
        **params: Parameters of the method. ``domain`` is always set to
            DOMAIN, overriding any value supplied by the caller.

    Returns:
        The ``result`` member of the JSON response.

    Raises:
        requests.HTTPError: If the HTTP status code signals an error.
        RuntimeError: If the response has no ``result`` member; the decoded
            response is attached as the second exception argument.
    """
    # ``**params: T`` annotates every *value*, hence ``int | str`` rather
    # than a dict type in the signature.
    request_body = {'method': method, 'params': {**params, 'domain': DOMAIN}}
    headers = {'Authorization': f'Njalla {API_TOKEN}'}
    response = requests.post(NJALLA, json=request_body, headers=headers, timeout=30)
    response.raise_for_status()
    j_response = response.json()
    if 'result' not in j_response:
        raise RuntimeError('API Error', j_response)
    return j_response['result']
|
|
|
|
|
|
def add_tlsa(new_tlsa: str) -> None:
    """Publish ``new_tlsa`` as a TLSA record at RECORD_NAME with a 3600 s TTL.

    Args:
        new_tlsa: Complete record content, e.g. ``'3 1 1 <hex digest>'``.
    """
    print('Adding new TLSA record')
    # Use the shared constant so the record name follows MX_SERVER instead of
    # a second, hard-coded copy of '_25._tcp.mx1'.
    njalla_api('add-record', name=RECORD_NAME, content=new_tlsa, ttl=3600, type='TLSA')
|
|
|
|
|
|
def remove_tlsa(entry_id: str) -> None:
    """Delete the DNS record identified by ``entry_id`` via the Njalla API."""
    print('Removing old TLSA record')
    request = {'id': entry_id}
    njalla_api('remove-record', **request)
|
|
|
|
|
|
def list_tlsa() -> dict[str, list[dict[str, str]]]:
    """Return the TLSA records published at RECORD_NAME.

    ``list-records`` is issued for the whole domain, so the answer is
    narrowed here; otherwise a caller pruning "stale" entries would also be
    handed unrelated records of the zone.

    Returns:
        The API result with its ``'records'`` list reduced to the entries
        whose type is ``'TLSA'`` and whose name is RECORD_NAME.

    Raises:
        KeyError: If the API result has no ``'records'`` member.
    """
    result = njalla_api('list-records')
    # NOTE(review): assumes every listed record exposes the same 'type' and
    # 'name' fields that add-record is given - confirm against the Njalla API.
    tlsa_records = [
        record
        for record in result['records']
        if record.get('type') == 'TLSA' and record.get('name') == RECORD_NAME
    ]
    return {**result, 'records': tlsa_records}
|
|
|
|
|
|
def get_cert_hash(cert: x509.Certificate) -> str:
    """Return the hex SHA-256 digest of the certificate's public key.

    The key is serialised as DER SubjectPublicKeyInfo before hashing.
    """
    spki_der = cert.public_key().public_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    hasher = hashes.Hash(hashes.SHA256())
    hasher.update(spki_der)
    return hasher.finalize().hex()
|
|
|
|
|
|
def _relink(links: dict[Path, Path]) -> None:
    """Re-create every key of ``links`` as a symlink pointing at its value."""
    for link, target in links.items():
        link.unlink()
        link.symlink_to(target)


def renew_cert(current_cert_hash: str, old_links: dict[Path, Path]) -> None:
    """Renew the certificate and switch to it only after the TLSA rotation.

    The new certificate is obtained first, then the previous symlinks are put
    back so the old certificate stays selected while rotate_tlsa() publishes
    the new record and waits. Afterwards the new symlinks are installed and
    maddy is reloaded.

    Args:
        current_cert_hash: Public-key hash of the certificate in use now.
        old_links: Symlink -> target mapping captured before the renewal.

    Raises:
        RuntimeError: If the renewed certificate has the same public-key hash
            as the current one.
    """
    # Function-scope import keeps this fix self-contained. GLibClientUnix adds
    # Unix file-descriptor passing to the proxy; without it the descriptor
    # returned by Inhibit never reaches this process.
    from dasbus.unix import GLibClientUnix

    bus = SystemMessageBus()
    logind = bus.get_proxy('org.freedesktop.login1', '/org/freedesktop/login1', client=GLibClientUnix)
    # logind holds the inhibitor lock only while this descriptor stays open,
    # so keep it for the whole rotation and release it explicitly at the end.
    inhibit_fd = logind.Inhibit('shutdown', sys.argv[0], 'MX certs + TLSA rotation in progress', 'block')
    try:
        new_cert_hash, new_links, systemd = renew_local(bus)
        if new_cert_hash == current_cert_hash:
            raise RuntimeError('New cert is the same as current. Perhaps the script ran prematurely?')

        # restore old symlinks in /live/ for a bit
        _relink(old_links)

        # this little maneuver's gonna cost us 3 hours
        print('Waiting 3h for TLSA changes to propagate')
        rotate_tlsa(current_cert_hash, new_cert_hash)

        # use new certs in /live/
        _relink(new_links)

        # alert maddy to new certs
        systemd.ReloadUnit('maddy.service', 'replace')
        print('Reloaded maddy')
    finally:
        os.close(inhibit_fd)
|
|
|
|
|
|
def save_live_symlinks() -> dict[Path, Path]:
    """Snapshot the symlinks found directly inside LIVE_CERT_PATH.

    Returns:
        A mapping from each symlink to its fully resolved target. Entries
        that are not symlinks are skipped. One line is printed per link.

    Raises:
        OSError: If a link cannot be resolved (resolution is strict).
    """
    snapshot = {}
    for entry in LIVE_CERT_PATH.iterdir():
        if not entry.is_symlink():
            continue
        snapshot[entry] = entry.resolve(strict=True)
        print(f'{entry} is symlink to {snapshot[entry]}')
    return snapshot
|
|
|
|
|
|
def rotate_cert_and_tlsa(current_cert: x509.Certificate) -> None:
    """Run one complete rotation, starting from the certificate in use now.

    Hashes the current certificate, records the current live symlinks, and
    hands both to renew_cert().
    """
    print('Getting current cert hash')
    current_hash = get_cert_hash(current_cert)

    print('Saving current (old) symlinks')
    previous_links = save_live_symlinks()

    print('Renewing cert')
    renew_cert(current_cert_hash=current_hash, old_links=previous_links)
|
|
|
|
|
|
def renew_local(bus: dasbus.connection.MessageBus) -> tuple[str, dict[Path, Path], dasbus.connection.ObjectProxy]:
    """Renew the MX certificate through certbot's webroot challenge.

    A temporary Caddy site serves the challenge on port 80 and the port is
    opened in firewalld for the duration. Both are undone again even when the
    renewal fails.

    Args:
        bus: System message bus used to reach systemd and firewalld.

    Returns:
        A tuple of the new certificate's public-key hash, the live symlinks
        as they are after the renewal, and the systemd manager proxy.

    Raises:
        dasbus.error.DBusError: If firewalld rejects the port change for any
            reason other than the port already being open.
        subprocess.CalledProcessError: If setfacl exits with a non-zero status.
    """
    # serve challenges from subdomain
    caddy_subconf_content = f'{MX_SUBDOMAIN}:80 {{\n root * /srv/{DOMAIN} \n file_server \n }}'
    caddy_subconf_file = Path(f'/etc/caddy/Caddyfile.d/{MX_SUBDOMAIN}')
    systemd = bus.get_proxy('org.freedesktop.systemd1', '/org/freedesktop/systemd1')
    firewalld = bus.get_proxy('org.fedoraproject.FirewallD1', '/org/fedoraproject/FirewallD1')

    # Only close the port afterwards if this run is the one that opened it.
    port_opened = False
    caddy_subconf_file.write_text(caddy_subconf_content)
    try:
        systemd.ReloadUnit('caddy.service', 'replace')
        print('Caddy reloaded')

        # make holes in FW for plaintext traffic
        try:
            zone = firewalld.addPort('', '80', 'tcp', 0)
        except dasbus.error.DBusError as e:
            # The exception object never equals a string; inspect its message
            # and let every other firewalld failure propagate.
            if 'ALREADY_ENABLED' not in str(e):
                raise
            print('Port 80/tcp already open')
        else:
            port_opened = True
            print(f'Added port 80/tcp to {zone}')

        # non-interactive (-n) renewals insert random delays for over 5 minutes
        cbmain.main(['renew', '--webroot', '-w', f'/srv/{DOMAIN}', '--cert-name', MX_SUBDOMAIN, '-n'])
        print('Cert renewed')
        subprocess.run(
            ['/bin/setfacl', '--recursive', '--logical', '--modify', 'u:maddy:rX', f'/etc/letsencrypt/live/{MX_SUBDOMAIN}'],
            check=True,
        )
        print("Cert acl'ed")
    finally:
        # Undo the temporary exposure whether or not the renewal succeeded.
        try:
            if port_opened:
                zone = firewalld.removePort('', '80', 'tcp')
                print(f'Removed port 80/tcp from {zone}')
        finally:
            caddy_subconf_file.unlink(missing_ok=True)
            systemd.ReloadUnit('caddy.service', 'replace')
            print('Caddy reloaded')

    # live cert changed, get new hash, save links
    pem_data = (LIVE_CERT_PATH / 'cert.pem').read_bytes()
    new_cert = x509.load_pem_x509_certificate(pem_data)
    new_hash = get_cert_hash(new_cert)
    new_links = save_live_symlinks()

    return new_hash, new_links, systemd
|
|
|
|
|
|
def rotate_tlsa(current_cert_hash: str, new_cert_hash: str, propagation_delay: float = 60 * 60 * 3) -> None:
    """Replace the published TLSA record while keeping the old one during the wait.

    Stale TLSA records at RECORD_NAME are removed, the record for the new
    certificate is added, and only after ``propagation_delay`` seconds are
    the record(s) of the current certificate removed.

    Args:
        current_cert_hash: Hex public-key hash of the certificate in use now.
        new_cert_hash: Hex public-key hash of the renewed certificate.
        propagation_delay: Seconds to wait before dropping the old record;
            defaults to three hours.
    """
    # get tlsas from njalla, get current id(s), delete extraneous
    wanted = current_cert_hash.lower()
    current_ids = []
    for entry in list_tlsa()['records']:
        # Never touch anything that is not a TLSA record at our own name,
        # even if the listing hands back the whole zone.
        if entry.get('type') != 'TLSA' or entry.get('name') != RECORD_NAME:
            continue
        # Stored content is '3 1 1 <hash>', so compare against its tail
        # rather than the whole string.
        if entry['content'].lower().endswith(wanted):
            current_ids.append(entry['id'])
        else:
            remove_tlsa(entry['id'])

    add_tlsa(f'3 1 1 {new_cert_hash}')

    # give DNS time to propagate
    time.sleep(propagation_delay)

    if not current_ids:
        print('No previous entry to remove.')
    for entry_id in current_ids:
        remove_tlsa(entry_id)
|
|
|
|
|
|
def main() -> None:
    """Rotate the certificate and TLSA record when fewer than 8 days of validity remain."""
    pem_data = (LIVE_CERT_PATH / 'cert.pem').read_bytes()
    current_cert = x509.load_pem_x509_certificate(pem_data)
    to_expiry = current_cert.not_valid_after_utc - datetime.datetime.now(datetime.UTC)
    if to_expiry < datetime.timedelta(days=8):
        rotate_cert_and_tlsa(current_cert)
        print('TLSA rotation complete')
    else:
        # Print whole days only: a timedelta formats itself with its own unit
        # ("29 days, 3:02:01"), which made the old message read "... days days".
        print(f'Cert has {to_expiry.days} days remaining. Nothing to do.')


# Guard the entry point so importing the module does not start a rotation.
if __name__ == '__main__':
    main()
|