My first responsibility after starting at TiltingPoint was to fix their data ingestion pipeline. Most of the daily jobs ran on Jenkins with no retry logic, no logging, and no graceful error handling, so I replaced the whole thing with Apache Airflow, a workflow scheduler originally developed at Airbnb.
Airflow comes with a rich set of features out of the box: clean UI, relational DB metastore, built-in scheduler, task sensors, logging, etc., but I made a few customizations that helped make it more useful and secure.
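For a sense of how little it takes to get the retry behavior and per-task logging those Jenkins jobs were missing, here is a minimal sketch of a daily DAG; the DAG id, task, and schedule are illustrative, not part of our actual pipeline.
import datetime

from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator

dag = DAG(
    "daily_ingest_example",
    default_args={
        "owner": "airflow",
        "start_date": datetime.datetime(2018, 5, 22),
        # Every task in the DAG gets three retries, five minutes apart,
        # and Airflow keeps a separate log file per try.
        "retries": 3,
        "retry_delay": datetime.timedelta(minutes=5),
    },
    schedule_interval="0 0 * * *",
)

ingest = BashOperator(
    task_id="run_ingest",
    bash_command="echo 'run the ingestion job here'",
    dag=dag,
)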
Here’s a quick guide to those customizations (written for Airflow 1.9).
There are a few Stack Overflow posts about how to ship worker-process logs to S3. None of them are complete, but I managed to piece them together and get it working. Below is the file I used, adapted from a Stack Overflow answer; I keep it saved as config/log_config.py in the project directory. You'll definitely want this, because with the default local logging, if a worker instance dies, all of its logs die with it.
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
import os

from airflow import configuration as conf

# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.
LOG_LEVEL = conf.get('core', 'LOGGING_LEVEL').upper()
LOG_FORMAT = conf.get('core', 'log_format')

BASE_LOG_FOLDER = conf.get('core', 'BASE_LOG_FOLDER')
PROCESSOR_LOG_FOLDER = conf.get('scheduler', 'child_process_log_directory')

FILENAME_TEMPLATE = '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log'
PROCESSOR_FILENAME_TEMPLATE = '{{ filename }}.log'

DEFAULT_LOGGING_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'airflow.task': {
            'format': LOG_FORMAT,
        },
        'airflow.processor': {
            'format': LOG_FORMAT,
        },
    },
    'handlers': {
        'console': {
            'class': 'logging.StreamHandler',
            'formatter': 'airflow.task',
            'stream': 'ext://sys.stdout'
        },
        'file.task': {
            'class': 'airflow.utils.log.file_task_handler.FileTaskHandler',
            'formatter': 'airflow.task',
            'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
            'filename_template': FILENAME_TEMPLATE,
        },
        'file.processor': {
            'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler',
            'formatter': 'airflow.processor',
            'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
            'filename_template': PROCESSOR_FILENAME_TEMPLATE,
        }
        # When using s3 or gcs, provide a customized LOGGING_CONFIG
        # in airflow_local_settings within your PYTHONPATH, see UPDATING.md
        # for details
        # 's3.task': {
        #     'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler',
        #     'formatter': 'airflow.task',
        #     'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
        #     's3_log_folder': S3_LOG_FOLDER,
        #     'filename_template': FILENAME_TEMPLATE,
        # },
        # 'gcs.task': {
        #     'class': 'airflow.utils.log.gcs_task_handler.GCSTaskHandler',
        #     'formatter': 'airflow.task',
        #     'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
        #     'gcs_log_folder': GCS_LOG_FOLDER,
        #     'filename_template': FILENAME_TEMPLATE,
        # },
    },
    'loggers': {
        '': {
            'handlers': ['console'],
            'level': LOG_LEVEL
        },
        'airflow': {
            'handlers': ['console'],
            'level': LOG_LEVEL,
            'propagate': False,
        },
        'airflow.processor': {
            'handlers': ['file.processor'],
            'level': LOG_LEVEL,
            'propagate': True,
        },
        'airflow.task': {
            'handlers': ['file.task'],
            'level': LOG_LEVEL,
            'propagate': False,
        },
        'airflow.task_runner': {
            'handlers': ['file.task'],
            'level': LOG_LEVEL,
            'propagate': True,
        },
    }
}

REMOTE_BASE_LOG_FOLDER = conf.get("core", "REMOTE_BASE_LOG_FOLDER")

REMOTE_HANDLERS = {
    's3': {
        'file.task': {
            'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler',
            'formatter': 'airflow.task',
            'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
            's3_log_folder': REMOTE_BASE_LOG_FOLDER,
            'filename_template': FILENAME_TEMPLATE,
        },
        'file.processor': {
            'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler',
            'formatter': 'airflow.processor',
            'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
            's3_log_folder': REMOTE_BASE_LOG_FOLDER,
            'filename_template': PROCESSOR_FILENAME_TEMPLATE,
        },
    },
}

REMOTE_LOGGING = conf.get('core', 'remote_logging')

if REMOTE_LOGGING and REMOTE_BASE_LOG_FOLDER.startswith('s3://'):
    DEFAULT_LOGGING_CONFIG['handlers'].update(REMOTE_HANDLERS['s3'])
You’ll also need these settings in your airflow.cfg file:
[core]
# ...
# Airflow can store logs remotely in AWS S3 or Google Cloud Storage. Users
# must supply an Airflow connection id that provides access to the storage
# location.
remote_logging = True
remote_base_log_folder = s3://mybucket/airflow/logs
# Logging class
# Specify the class that will specify the logging configuration
# This class has to be on the python classpath
logging_config_class = config.log_config.DEFAULT_LOGGING_CONFIG
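A couple of gotchas, assuming a project layout like mine: the config directory has to be an importable package (an empty config/__init__.py, with the project root on the PYTHONPATH), or Airflow won't be able to load the custom logging class. You will also most likely want remote_log_conn_id in [core] pointed at an S3-capable Airflow connection, since that is what the S3 handler uses to authenticate. A quick sanity check from the project root:
from config.log_config import DEFAULT_LOGGING_CONFIG

# With remote_logging on and an s3:// bucket configured, the 'file.task'
# handler should have been swapped out for the S3TaskHandler.
print(DEFAULT_LOGGING_CONFIG["handlers"]["file.task"]["class"])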
I used an Airflow plugin from this gist to force HTTPS on all requests. The Airflow frontend is written in Flask, so the plugin extends the underlying Flask application that powers the webserver. Place this file in the plugins directory of your Airflow project.
"""Force SSL."""
from flask import Blueprint
from flask import redirect
# from flask import url_for
from flask import request
from flask import current_app
from airflow.plugins_manager import AirflowPlugin
YEAR_IN_SECS = 31536000
ssl_bp = Blueprint('ssl_everything', __name__)
@ssl_bp.before_app_request
def before_request():
app = current_app._get_current_object()
criteria = [
request.is_secure,
app.debug,
request.headers.get('X-Forwarded-Proto', 'http') == 'https'
]
if not any(criteria):
if request.url.startswith('http://'):
url = request.url.replace('http://', 'https://', 1)
r = redirect(url, code=302)
return r
@ssl_bp.after_app_request
def after_request(response):
hsts_policy = 'max-age={0}'.format(YEAR_IN_SECS)
response.headers.setdefault('Strict-Transport-Security', hsts_policy)
return response
class AirflowSSLPlugin(AirflowPlugin):
name = 'ssl_everything'
flask_blueprints = [ssl_bp]
Make sure that airflow.cfg points to the plugins directory:
[core]
# ...
# Where your Airflow plugins are stored
plugins_folder = /myairflow/plugins
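Note that the redirect only fires for requests that are not already secure, so if you terminate TLS at a load balancer or reverse proxy, the proxy needs to pass X-Forwarded-Proto through; otherwise requests that actually arrived over HTTPS will be redirected in a loop. Once deployed, here is a quick smoke test; the hostname is a placeholder for your own webserver:
import requests

# A plain-HTTP request should now bounce to HTTPS.
resp = requests.get("http://airflow.mycompany.com/admin/", allow_redirects=False)
print(resp.status_code, resp.headers.get("Location"))  # expect 302 and an https:// URL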
Airflow comes with GitHub Enterprise OAuth support out of the box, but it's broken in Airflow 1.9 because the logger mixin is named incorrectly. Luckily, only a few small changes are needed to make it work with standard GitHub organizations. When deploying, before launching Airflow, overwrite Airflow's packaged file airflow/contrib/auth/backends/github_enterprise_auth.py with this one:
# Copyright 2015 Matthew Pelland (matt@pelland.io)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
import flask_login

# Need to expose these downstream
# pylint: disable=unused-import
from flask_login import (current_user,
                         logout_user,
                         login_required,
                         login_user)
# pylint: enable=unused-import

from flask import url_for, redirect, request

from flask_oauthlib.client import OAuth

from airflow import models, configuration, settings
from airflow.configuration import AirflowConfigException
from airflow.utils.log.logging_mixin import LoggingMixin

log = LoggingMixin().log


def get_config_param(param):
    return str(configuration.get('github_enterprise', param))


class GHEUser(models.User):

    def __init__(self, user):
        self.user = user

    def is_active(self):
        '''Required by flask_login'''
        return True

    def is_authenticated(self):
        '''Required by flask_login'''
        return True

    def is_anonymous(self):
        '''Required by flask_login'''
        return False

    def get_id(self):
        '''Returns the current user id as required by flask_login'''
        return self.user.get_id()

    def data_profiling(self):
        '''Provides access to data profiling tools'''
        return True

    def is_superuser(self):
        '''Access all the things'''
        return True


class AuthenticationError(Exception):
    pass


class GHEAuthBackend(object):

    def __init__(self):
        self.ghe_host = get_config_param('host')
        self.login_manager = flask_login.LoginManager()
        self.login_manager.login_view = 'airflow.login'
        self.flask_app = None
        self.ghe_oauth = None
        self.api_rev = None

    def ghe_api_route(self, leaf):
        if not self.api_rev:
            self.api_rev = get_config_param('api_rev')

        # Modified to work with GitHub organizations
        return '/'.join(['https:/',
                         "api." + self.ghe_host,
                         leaf.strip('/')])

    def init_app(self, flask_app):
        self.flask_app = flask_app

        self.login_manager.init_app(self.flask_app)

        self.ghe_oauth = OAuth(self.flask_app).remote_app(
            'ghe',
            consumer_key=get_config_param('client_id'),
            consumer_secret=get_config_param('client_secret'),
            # need read:org to get team member list
            request_token_params={'scope': 'user:email,read:org'},
            base_url=self.ghe_host,
            request_token_url=None,
            access_token_method='POST',
            access_token_url=''.join(['https://',
                                      self.ghe_host,
                                      '/login/oauth/access_token']),
            authorize_url=''.join(['https://',
                                   self.ghe_host,
                                   '/login/oauth/authorize']))

        self.login_manager.user_loader(self.load_user)

        self.flask_app.add_url_rule(get_config_param('oauth_callback_route'),
                                    'ghe_oauth_callback',
                                    self.oauth_callback)

    def login(self, request):
        log.debug('Redirecting user to GHE login')
        return self.ghe_oauth.authorize(callback=url_for(
            'ghe_oauth_callback',
            _external=True,
            next=request.args.get('next') or request.referrer or None))

    def get_ghe_user_profile_info(self, ghe_token):
        resp = self.ghe_oauth.get(self.ghe_api_route('/user'),
                                  token=(ghe_token, ''))

        if not resp or resp.status != 200:
            raise AuthenticationError(
                'Failed to fetch user profile, status ({0})'.format(
                    resp.status if resp else 'None'))

        return resp.data['login'], resp.data['email']

    def ghe_team_check(self, username, ghe_token):
        try:
            # the response from ghe returns the id of the team as an integer
            try:
                allowed_teams = [int(team.strip())
                                 for team in
                                 get_config_param('allowed_teams').split(',')]
            except ValueError:
                # this is to deprecate using the string name for a team
                raise ValueError('it appears that you are using the string name for a team, '
                                 'please use the id number instead')

        except AirflowConfigException:
            # No allowed teams defined, let anyone in GHE in.
            return True

        # https://developer.github.com/v3/orgs/teams/#list-user-teams
        resp = self.ghe_oauth.get(self.ghe_api_route('/user/teams'),
                                  token=(ghe_token, ''))

        if not resp or resp.status != 200:
            raise AuthenticationError(
                'Bad response from GHE ({0})'.format(
                    resp.status if resp else 'None'))

        for team in resp.data:
            # mylons: previously this line used to be if team['slug'] in teams
            # however, teams are part of organizations. organizations are unique,
            # but teams are not therefore 'slug' for a team is not necessarily unique.
            # use id instead
            if team['id'] in allowed_teams:
                return True

        log.debug('Denying access for user "%s", not a member of "%s"',
                  username,
                  str(allowed_teams))

        return False

    def load_user(self, userid):
        if not userid or userid == 'None':
            return None

        session = settings.Session()
        user = session.query(models.User).filter(
            models.User.id == int(userid)).first()
        session.expunge_all()
        session.commit()
        session.close()
        return GHEUser(user)

    def oauth_callback(self):
        log.debug('GHE OAuth callback called')

        next_url = request.args.get('next') or url_for('admin.index')

        resp = self.ghe_oauth.authorized_response()

        try:
            if resp is None:
                raise AuthenticationError(
                    'Null response from GHE, denying access.'
                )

            ghe_token = resp['access_token']

            username, email = self.get_ghe_user_profile_info(ghe_token)

            if not self.ghe_team_check(username, ghe_token):
                return redirect(url_for('airflow.noaccess'))

        except AuthenticationError:
            log.exception('')
            return redirect(url_for('airflow.noaccess'))

        session = settings.Session()

        user = session.query(models.User).filter(
            models.User.username == username).first()

        if not user:
            user = models.User(
                username=username,
                email=email,
                is_superuser=False)

        session.merge(user)
        session.commit()
        login_user(GHEUser(user))
        session.commit()
        session.close()

        return redirect(next_url)


login_manager = GHEAuthBackend()


def login(self, request):
    return login_manager.login(request)
This version connects to the github.com API rather than an Enterprise URI, so you can keep the same configuration options in airflow.cfg, just pointed at your standard GitHub organization:
[webserver]
# ...
# Set to true to turn on authentication:
# http://pythonhosted.org/airflow/security.html#web-authentication
authenticate = True
auth_backend = airflow.contrib.auth.backends.github_enterprise_auth
[github_enterprise]
host = github.com
# From your OAuth app
client_id = your_client_id
client_secret = your_client_secret
# Make sure this is the same as your OAuth app callback path on your org's GitHub OAuth app
oauth_callback_route = /mycompany/oauth/callback
# Get organization team id numbers from the github API
# https://developer.github.com/v3/teams/
allowed_teams = 123,456,789
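One way to look up those team id numbers is to hit the v3 REST endpoint directly. Here's a rough sketch, assuming a personal access token with the read:org scope; the org name and token are placeholders:
import requests


def list_team_ids(org, token):
    """Return a {team slug: team id} mapping for an organization."""
    resp = requests.get(
        "https://api.github.com/orgs/{0}/teams".format(org),
        headers={"Authorization": "token {0}".format(token)},
    )
    resp.raise_for_status()
    return {team["slug"]: team["id"] for team in resp.json()}


print(list_team_ids("mycompany", "<personal-access-token>"))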
You’ll probably want to subclass Airflow's BaseOperator when building your own operators. When you do, you can pass an on_failure_callback parameter up to the parent BaseOperator class. The callback can be a function like the one below, written in the same style as my celery-slack Python package. Keep the webhook secret by storing it as an Airflow Variable in your backend metastore (I used a PostgreSQL database).
from airflow.models import Variable
import requests

BASE_URL = "https://airflow.mycompany.com"


def post_failure_to_slack(context):
    """Post failure to slack."""
    dag_id = context["dag"].dag_id
    task_id = context["task"].task_id
    author = context["dag"].owner
    execution_date = context["execution_date"].strftime("%Y-%m-%dT%H:%M:%S")
    log_url = f"{BASE_URL}/admin/airflow/log?task_id={task_id}&dag_id={dag_id}&execution_date={execution_date}"  # noqa

    attachment = {
        "fallback": f"<{log_url}|View failure log>",
        "color": "danger",
        "pretext": f"<{log_url}|View failure log>",
        "author_name": author,
        "fields": [
            {
                "title": "Failed task",
                "value": f"{dag_id}.{task_id}",
                "short": False,
            }
        ]
    }
    payload = {"text": "", "attachments": [attachment]}
    response = requests.post(Variable.get("SLACK_WEBHOOK"), json=payload)
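And here is a minimal sketch of the subclassing approach described above, handing the callback up to the parent BaseOperator; the operator name and the import path for the callback are illustrative:
from airflow.models import BaseOperator
from airflow.utils.decorators import apply_defaults

# Hypothetical module holding the post_failure_to_slack function above.
from mycompany.callbacks import post_failure_to_slack


class MyCompanyOperator(BaseOperator):
    """Base class for our custom operators; failures get posted to Slack."""

    @apply_defaults
    def __init__(self, **kwargs):
        # Hand the callback to BaseOperator unless the caller supplied their own.
        kwargs.setdefault("on_failure_callback", post_failure_to_slack)
        super(MyCompanyOperator, self).__init__(**kwargs)

    def execute(self, context):
        raise NotImplementedError("Concrete operators implement their work here.")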
After configuring the Slack webhook for your Airflow app in Slack, you'll get a message on each task failure with a link to the Airflow webserver's log page for that task.
While not really mentioned in the docs, you can dynamically create DAGs using a pattern similar to the following. You must add the DAG instances to the module's globals so they are picked up by the scheduler. This example creates three separate DAGs, each consisting of a single BashOperator that echoes one of the param values.
"""Dynamically define DAGs."""
import datetime
from airflow.models import DAG
from airflow import operators
start_date = datetime.datetime(2018, 5, 22)
params = [
"one",
"two",
"three",
]
# Dynamically generate the dags from the list of parameters
dags = {}
for param in params:
dag_args = {
"owner": "airflow",
"start_date": start_date,
}
dag = DAG(
f"my_dag_{param}",
default_args=dag_args,
schedule_interval="0 0 * * *",
)
databricks_process = operators.BashOperator(
task_id=f"echo_param_{param}",
bash_command=f"echo {param}"
dag=dag,
)
dags[f"dag_{param}"] = dag
# This is necessary to make dynamically generated dags visible to the scheduler
globals().update(dags)
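With this file in your dags folder, airflow list_dags should show my_dag_one, my_dag_two, and my_dag_three as three independent DAGs, each on its own daily schedule.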