
Airflow (1.9) Mods

2018-07-05

My first responsibility after starting at TiltingPoint was to fix their data ingestion pipeline. Most of the daily jobs were running on Jenkins without any retry logic, logging, or graceful handling of errors. I decided to replace the existing pipeline with Apache Airflow, originally developed at Airbnb.

Airflow comes with a rich set of features out of the box (clean UI, relational DB metastore, built-in scheduler, task sensors, logging, and more), but I made a few customizations that make it more useful and secure.

Here’s a quick guide (for Airflow 1.9).

Logging to S3

There are a few Stack Overflow posts about how to ship worker task logs to S3. None of them is complete on its own, but I managed to piece them together to get it working. Below is a copy of the file I used, adapted from a Stack Overflow answer; I have it saved as config/log_config.py in the project directory. You’ll definitely want this because with the default local logging, if a worker instance dies, so do all of its logs.

# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
import os

from airflow import configuration as conf

# TODO: Logging format and level should be configured
# in this file instead of from airflow.cfg. Currently
# there are other log format and level configurations in
# settings.py and cli.py. Please see AIRFLOW-1455.

LOG_LEVEL = conf.get('core', 'LOGGING_LEVEL').upper()
LOG_FORMAT = conf.get('core', 'log_format')

BASE_LOG_FOLDER = conf.get('core', 'BASE_LOG_FOLDER')
PROCESSOR_LOG_FOLDER = conf.get('scheduler', 'child_process_log_directory')

FILENAME_TEMPLATE = '{{ ti.dag_id }}/{{ ti.task_id }}/{{ ts }}/{{ try_number }}.log'
PROCESSOR_FILENAME_TEMPLATE = '{{ filename }}.log'

DEFAULT_LOGGING_CONFIG = {
    'version': 1,
    'disable_existing_loggers': False,
    'formatters': {
        'airflow.task': {
            'format': LOG_FORMAT,
        },
        'airflow.processor': {
            'format': LOG_FORMAT,
        },
    },
    'handlers': {
        'console': {
            'class': 'logging.StreamHandler',
            'formatter': 'airflow.task',
            'stream': 'ext://sys.stdout'
        },
        'file.task': {
            'class': 'airflow.utils.log.file_task_handler.FileTaskHandler',
            'formatter': 'airflow.task',
            'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
            'filename_template': FILENAME_TEMPLATE,
        },
        'file.processor': {
            'class': 'airflow.utils.log.file_processor_handler.FileProcessorHandler',
            'formatter': 'airflow.processor',
            'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
            'filename_template': PROCESSOR_FILENAME_TEMPLATE,
        }
        # When using s3 or gcs, provide a customized LOGGING_CONFIG
        # in airflow_local_settings within your PYTHONPATH, see UPDATING.md
        # for details
        # 's3.task': {
        #     'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler',
        #     'formatter': 'airflow.task',
        #     'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
        #     's3_log_folder': S3_LOG_FOLDER,
        #     'filename_template': FILENAME_TEMPLATE,
        # },
        # 'gcs.task': {
        #     'class': 'airflow.utils.log.gcs_task_handler.GCSTaskHandler',
        #     'formatter': 'airflow.task',
        #     'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
        #     'gcs_log_folder': GCS_LOG_FOLDER,
        #     'filename_template': FILENAME_TEMPLATE,
        # },
    },
    'loggers': {
        '': {
            'handlers': ['console'],
            'level': LOG_LEVEL
        },
        'airflow': {
            'handlers': ['console'],
            'level': LOG_LEVEL,
            'propagate': False,
        },
        'airflow.processor': {
            'handlers': ['file.processor'],
            'level': LOG_LEVEL,
            'propagate': True,
        },
        'airflow.task': {
            'handlers': ['file.task'],
            'level': LOG_LEVEL,
            'propagate': False,
        },
        'airflow.task_runner': {
            'handlers': ['file.task'],
            'level': LOG_LEVEL,
            'propagate': True,
        },
    }
}

REMOTE_BASE_LOG_FOLDER = conf.get("core", "REMOTE_BASE_LOG_FOLDER")


REMOTE_HANDLERS = {
    's3': {
        'file.task': {
            'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler',
            'formatter': 'airflow.task',
            'base_log_folder': os.path.expanduser(BASE_LOG_FOLDER),
            's3_log_folder': REMOTE_BASE_LOG_FOLDER,
            'filename_template': FILENAME_TEMPLATE,
        },
        'file.processor': {
            'class': 'airflow.utils.log.s3_task_handler.S3TaskHandler',
            'formatter': 'airflow.processor',
            'base_log_folder': os.path.expanduser(PROCESSOR_LOG_FOLDER),
            's3_log_folder': REMOTE_BASE_LOG_FOLDER,
            'filename_template': PROCESSOR_FILENAME_TEMPLATE,
        },
    },
}

# Use getboolean so that a value of False in airflow.cfg is not treated as a
# truthy (non-empty) string.
REMOTE_LOGGING = conf.getboolean('core', 'remote_logging')

if REMOTE_LOGGING and REMOTE_BASE_LOG_FOLDER.startswith('s3://'):
    DEFAULT_LOGGING_CONFIG['handlers'].update(REMOTE_HANDLERS['s3'])

You’ll also need these settings in your airflow.cfg file:

[core]
# ...

# Airflow can store logs remotely in AWS S3 or Google Cloud Storage. Users
# must supply an Airflow connection id that provides access to the storage
# location.
remote_logging = True
remote_base_log_folder = s3://mybucket/airflow/logs

# Logging class
# Specify the class that will specify the logging configuration
# This class has to be on the python classpath
logging_config_class = config.log_config.DEFAULT_LOGGING_CONFIG
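
Before deploying, it can be worth a quick sanity check that the S3 handler actually replaced the local file handler. Here is a hypothetical snippet, assuming the config package is importable (it has to be on your PYTHONPATH anyway for logging_config_class to resolve) and that remote logging is enabled with an s3:// base folder:

# Hypothetical check: with remote_logging enabled and an s3:// base folder,
# log_config.py swaps the 'file.task' handler for the S3TaskHandler.
from config.log_config import DEFAULT_LOGGING_CONFIG

handler = DEFAULT_LOGGING_CONFIG['handlers']['file.task']
assert handler['class'].endswith('S3TaskHandler'), handler['class']
print('Task logs will be shipped to', handler['s3_log_folder'])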

Forcing SSL

I used an Airflow plugin from this gist to force HTTPS on all requests. The Airflow frontend UI is written in Flask, so this code extends the underlying Flask application that powers the webserver. Place this file in the plugins directory of your Airflow project.

"""Force SSL."""
from flask import Blueprint
from flask import redirect
# from flask import url_for
from flask import request
from flask import current_app
from airflow.plugins_manager import AirflowPlugin


YEAR_IN_SECS = 31536000
ssl_bp = Blueprint('ssl_everything', __name__)


@ssl_bp.before_app_request
def before_request():
    app = current_app._get_current_object()
    criteria = [
        request.is_secure,
        app.debug,
        request.headers.get('X-Forwarded-Proto', 'http') == 'https'
    ]

    if not any(criteria):
        if request.url.startswith('http://'):
            url = request.url.replace('http://', 'https://', 1)
            r = redirect(url, code=302)
            return r


@ssl_bp.after_app_request
def after_request(response):
    hsts_policy = 'max-age={0}'.format(YEAR_IN_SECS)
    response.headers.setdefault('Strict-Transport-Security', hsts_policy)
    return response


class AirflowSSLPlugin(AirflowPlugin):
    name = 'ssl_everything'
    flask_blueprints = [ssl_bp]

Make sure that airflow.cfg points to the plugins directory:

[core]
# ...

# Where your Airflow plugins are stored
plugins_folder = /myairflow/plugins
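
To verify the plugin is working, you can hit the webserver with and without TLS. This is just a hypothetical smoke test, assuming plain HTTP still reaches the app (directly or through a load balancer) at airflow.mycompany.com:

"""Hypothetical smoke test for the force-SSL plugin."""
import requests

# A plain HTTP request should be answered with a redirect to HTTPS.
resp = requests.get('http://airflow.mycompany.com/admin/', allow_redirects=False)
assert resp.status_code == 302
assert resp.headers['Location'].startswith('https://')

# HTTPS responses should carry the HSTS policy header set in after_request.
secure = requests.get('https://airflow.mycompany.com/admin/')
print(secure.headers.get('Strict-Transport-Security'))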

GitHub OAuth

Airflow comes with GitHub Enterprise OAuth configuration out of the box, but it’s actually broken in Airflow 1.9 because the logger mixin is named incorrectly. Luckily, only a few small changes are needed to make it work with standard GitHub organizations. When deploying, before launching Airflow, be sure to overwrite Airflow’s installed package file airflow/contrib/auth/backends/github_enterprise_auth.py with this one:

# Copyright 2015 Matthew Pelland (matt@pelland.io)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
import flask_login

# Need to expose these downstream
# pylint: disable=unused-import
from flask_login import (current_user,
                         logout_user,
                         login_required,
                         login_user)
# pylint: enable=unused-import

from flask import url_for, redirect, request

from flask_oauthlib.client import OAuth

from airflow import models, configuration, settings
from airflow.configuration import AirflowConfigException
from airflow.utils.log.logging_mixin import LoggingMixin

log = LoggingMixin().log


def get_config_param(param):
    return str(configuration.get('github_enterprise', param))


class GHEUser(models.User):

    def __init__(self, user):
        self.user = user

    def is_active(self):
        '''Required by flask_login'''
        return True

    def is_authenticated(self):
        '''Required by flask_login'''
        return True

    def is_anonymous(self):
        '''Required by flask_login'''
        return False

    def get_id(self):
        '''Returns the current user id as required by flask_login'''
        return self.user.get_id()

    def data_profiling(self):
        '''Provides access to data profiling tools'''
        return True

    def is_superuser(self):
        '''Access all the things'''
        return True


class AuthenticationError(Exception):
    pass


class GHEAuthBackend(object):

    def __init__(self):
        self.ghe_host = get_config_param('host')
        self.login_manager = flask_login.LoginManager()
        self.login_manager.login_view = 'airflow.login'
        self.flask_app = None
        self.ghe_oauth = None
        self.api_rev = None

    def ghe_api_route(self, leaf):
        if not self.api_rev:
            self.api_rev = get_config_param('api_rev')
        # Modified to work with GitHub organizations
        return '/'.join(['https:/',
                         "api." + self.ghe_host,
                         leaf.strip('/')])

    def init_app(self, flask_app):
        self.flask_app = flask_app

        self.login_manager.init_app(self.flask_app)

        self.ghe_oauth = OAuth(self.flask_app).remote_app(
            'ghe',
            consumer_key=get_config_param('client_id'),
            consumer_secret=get_config_param('client_secret'),
            # need read:org to get team member list
            request_token_params={'scope': 'user:email,read:org'},
            base_url=self.ghe_host,
            request_token_url=None,
            access_token_method='POST',
            access_token_url=''.join(['https://',
                                      self.ghe_host,
                                      '/login/oauth/access_token']),
            authorize_url=''.join(['https://',
                                   self.ghe_host,
                                   '/login/oauth/authorize']))

        self.login_manager.user_loader(self.load_user)

        self.flask_app.add_url_rule(get_config_param('oauth_callback_route'),
                                    'ghe_oauth_callback',
                                    self.oauth_callback)

    def login(self, request):
        log.debug('Redirecting user to GHE login')
        return self.ghe_oauth.authorize(callback=url_for(
            'ghe_oauth_callback',
            _external=True,
            next=request.args.get('next') or request.referrer or None))

    def get_ghe_user_profile_info(self, ghe_token):
        resp = self.ghe_oauth.get(self.ghe_api_route('/user'),
                                  token=(ghe_token, ''))

        if not resp or resp.status != 200:
            raise AuthenticationError(
                'Failed to fetch user profile, status ({0})'.format(
                    resp.status if resp else 'None'))

        return resp.data['login'], resp.data['email']

    def ghe_team_check(self, username, ghe_token):
        try:
            # the response from ghe returns the id of the team as an integer
            try:
                allowed_teams = [int(team.strip())
                                 for team in
                                 get_config_param('allowed_teams').split(',')]
            except ValueError:
                # this is to deprecate using the string name for a team
                raise ValueError('it appears that you are using the string name for a team, '
                                 'please use the id number instead')

        except AirflowConfigException:
            # No allowed teams defined, let anyone in GHE in.
            return True

        # https://developer.github.com/v3/orgs/teams/#list-user-teams
        resp = self.ghe_oauth.get(self.ghe_api_route('/user/teams'),
                                  token=(ghe_token, ''))

        if not resp or resp.status != 200:
            raise AuthenticationError(
                'Bad response from GHE ({0})'.format(
                    resp.status if resp else 'None'))

        for team in resp.data:
            # mylons: previously this line used to be if team['slug'] in teams
            # however, teams are part of organizations. organizations are unique,
            # but teams are not therefore 'slug' for a team is not necessarily unique.
            # use id instead
            if team['id'] in allowed_teams:
                return True

        log.debug('Denying access for user "%s", not a member of "%s"',
                   username,
                   str(allowed_teams))

        return False

    def load_user(self, userid):
        if not userid or userid == 'None':
            return None

        session = settings.Session()
        user = session.query(models.User).filter(
            models.User.id == int(userid)).first()
        session.expunge_all()
        session.commit()
        session.close()
        return GHEUser(user)

    def oauth_callback(self):
        log.debug('GHE OAuth callback called')

        next_url = request.args.get('next') or url_for('admin.index')

        resp = self.ghe_oauth.authorized_response()

        try:
            if resp is None:
                raise AuthenticationError(
                    'Null response from GHE, denying access.'
                )

            ghe_token = resp['access_token']

            username, email = self.get_ghe_user_profile_info(ghe_token)

            if not self.ghe_team_check(username, ghe_token):
                return redirect(url_for('airflow.noaccess'))

        except AuthenticationError:
            log.exception('')
            return redirect(url_for('airflow.noaccess'))

        session = settings.Session()

        user = session.query(models.User).filter(
            models.User.username == username).first()

        if not user:
            user = models.User(
                username=username,
                email=email,
                is_superuser=False)

        session.merge(user)
        session.commit()
        login_user(GHEUser(user))
        session.commit()
        session.close()

        return redirect(next_url)

login_manager = GHEAuthBackend()


def login(self, request):
    return login_manager.login(request)

This version connects to the github.com API rather than an enterprise URI, which lets you reuse the GitHub Enterprise configuration options in airflow.cfg with a standard GitHub organization.
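
Because the file above replaces a module inside the installed Airflow package, it helps to locate that file programmatically during deployment. A small sketch, assuming Airflow and its OAuth dependencies are already installed in the deployment environment:

# Print the path of the installed module that the custom file above
# needs to overwrite before the webserver starts.
import airflow.contrib.auth.backends.github_enterprise_auth as ghe_auth

print(ghe_auth.__file__)

The relevant airflow.cfg settings then look like this: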

[webserver]
# ...

# Set to true to turn on authentication:
# http://pythonhosted.org/airflow/security.html#web-authentication
authenticate = True
auth_backend = airflow.contrib.auth.backends.github_enterprise_auth

[github_enterprise]
host = github.com
# From your OAuth app
client_id = your_client_id
client_secret = your_client_secret
# Must match the callback path configured on your org's GitHub OAuth app
oauth_callback_route = /mycompany/oauth/callback
# Get organization team id numbers from the github API
# https://developer.github.com/v3/teams/
allowed_teams = 123,456,789
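
To look up the numeric team ids for allowed_teams, you can query the GitHub teams API directly. A minimal sketch, assuming a personal access token with read:org scope in the GITHUB_TOKEN environment variable and a hypothetical organization name:

"""List team ids for a GitHub organization (hypothetical helper)."""
import os

import requests

ORG = "mycompany"  # replace with your organization
TOKEN = os.environ["GITHUB_TOKEN"]

resp = requests.get(
    f"https://api.github.com/orgs/{ORG}/teams",
    headers={"Authorization": f"token {TOKEN}"},
)
resp.raise_for_status()
for team in resp.json():
    print(team["id"], team["slug"])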

Slack messaging

You’ll probably want to subclass Airflow’s BaseOperator when building your own operators. When you do, you can pass an on_failure_callback argument up to the parent BaseOperator constructor. The callback can be a function like the one below, written in the same style as my celery-slack Python package. Keep the webhook URL secret by storing it as an Airflow Variable in your backend metastore (I used a PostgreSQL db).

from airflow.models import Variable
import requests


BASE_URL = "https://airflow.mycompany.com"


def post_failure_to_slack(context):
    """Post failure to slack."""
    dag_id = context["dag"].dag_id
    task_id = context["task"].task_id
    author = context["dag"].owner
    execution_date = context["execution_date"].strftime("%Y-%m-%dT%H:%M:%S")
    log_url = f"{BASE_URL}/admin/airflow/log?task_id={task_id}&dag_id={dag_id}&execution_date={execution_date}"  # noqa

    attachment = {
        "fallback": f"<{log_url}|View failure log>",
        "color": "danger",
        "pretext": f"<{log_url}|View failure log>",
        "author_name": author,
        "fields": [
            {
                "title": "Failed task",
                "value": f"{dag_id}.{task_id}",
                "short": False,
            }
        ]
    }

    payload = {"text": "", "attachments": [attachment]}
    response = requests.post(Variable.get("SLACK_WEBHOOK"), json=payload)
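
Hooking the callback up is then just a matter of passing it to your operators (directly, or through default_args so every task in the DAG inherits it). A minimal sketch with an illustrative DAG and task:

import datetime

from airflow.models import DAG
from airflow.operators.bash_operator import BashOperator

default_args = {
    "owner": "airflow",
    "start_date": datetime.datetime(2018, 7, 1),
    # every task in this DAG inherits the failure callback
    "on_failure_callback": post_failure_to_slack,
}

dag = DAG(
    "slack_callback_example",
    default_args=default_args,
    schedule_interval="0 0 * * *",
)

always_fails = BashOperator(
    task_id="always_fails",
    bash_command="exit 1",
    dag=dag,
)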

After configuring the incoming webhook for your Slack app, you’ll get messages like this on task failures, with links to the Airflow webserver’s log page for the failed task:

[Screenshot: Slack failure notification linking to the Airflow task log page]

Dynamic DAGs

While not really mentioned in the docs, you can dynamically create DAGs for Airflow using a pattern similar to the following. You must add the DAG instances to the module’s globals so they are picked up by the scheduler. This example creates three separate DAGs, each consisting of a single BashOperator that echoes one of the param values.

"""Dynamically define DAGs."""
import datetime

from airflow.models import DAG
from airflow import operators


start_date = datetime.datetime(2018, 5, 22)
params = [
    "one",
    "two",
    "three",
]

# Dynamically generate the dags from the list of parameters
dags = {}
for param in params:

    dag_args = {
        "owner": "airflow",
        "start_date": start_date,
    }

    dag = DAG(
        f"my_dag_{param}",
        default_args=dag_args,
        schedule_interval="0 0 * * *",
    )

    databricks_process = operators.BashOperator(
        task_id=f"echo_param_{param}",
        bash_command=f"echo {param}",
        dag=dag,
    )
    dags[f"dag_{param}"] = dag

# This is necessary to make dynamically generated dags visible to the scheduler
globals().update(dags)
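
You can confirm the scheduler will pick the generated DAGs up by parsing the DAGs folder in a Python shell; a quick check:

# Parse the configured dags folder and list the dynamically generated DAG ids.
from airflow.models import DagBag

bag = DagBag()
print(sorted(dag_id for dag_id in bag.dags if dag_id.startswith("my_dag_")))
# expected: ['my_dag_one', 'my_dag_three', 'my_dag_two']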

Further reading

Airflow

Apache Airflow
