SSO for your App via Auth0 + Nginx + Docker + Vouch-Proxy

This post is a tutorial on “How to set up SSO via Auth0 using Nginx and Vouch-Proxy”. I couldn’t find an existing nifty blog post on this, so I ended up having to figure it out myself. Here, I want to document the steps so that others (and future me) may have an easier time setting this up.

If you don’t want a full tutorial and are just looking for an example, here is a link to the repository with the final config files: https://github.com/FirefoxMetzger/sso_example

The setup I am presenting here works on localhost and is mainly aimed at local development. It is a Docker-based setup, so there are tons of existing tutorials for deployment. Another thing you may want to look into is hardening (making things super secure). I left this part out (for the most part) to avoid distraction; I really just want to focus on getting SSO up and running.

Setup Auth0

The first step is to set up Auth0 and create a new tenant. Make sure to pick a region that is close to your physical location; this affects not only login speed, but also how the data you send to Auth0 is handled (data protection laws).

Setup window for a new tenant (Oct 2020)

Currently (2020), this will create a default app and enable authentication via email/password as well as Google as a social login provider. We will use this default app. You can of course customize things, but I recommend you first set everything up following this tutorial and add your customization afterward.

Next, we will navigate to the settings of the default app.

Navigate to the settings page.

There are a few useful items in the settings which we will need, but the first thing is to allow users of our app to log in and log out. For this, we need to tell Auth0 which URLs are okay to use as callbacks for both login (Allowed Callback URLs) and logout (Allowed Logout URLs).

Navigate to Application URIs. For Allowed Callback URLs add http://localhost/sso/auth, https://localhost/sso/auth, and http://localhost:9090/auth. For Allowed Logout URLs add http://localhost, https://localhost, http://localhost/sso/validate, and http://localhost:9090/validate.

Configuration of Login and Logout inside Auth0.

Make sure to hit save changes at the bottom of the page.

We will delete most of these URLs as we move along; they mainly exist for testing (so that we can assemble this incrementally). The two HTTPS URLs are the final ones that we will keep when we are done. The URLs on port 9090 are for testing vouch-proxy, which by default runs on port 9090, and the remaining HTTP URLs are for testing nginx as a reverse proxy for vouch-proxy and your app.

While we are now done with the setup on Auth0’s side, don’t leave the settings page yet. At the top of the page you can find the application’s domain, client ID, and client secret. We will need this info in the next steps, so keep it around.

Client ID and Client Secret location.

Setup Vouch-Proxy

Vouch-Proxy runs almost out of the box; all we need to do is add a config file. The config below follows the example for a generic OIDC provider, which you can find in the vouch-proxy repo, with a few modifications to make it work with Auth0.

When you use this template, be sure to replace the Auth0 domain, the client ID, and the client secret with your own values from the previous step.

# vouch config
# bare minimum to get vouch running with OpenID Connect (such as okta)

vouch:
  logLevel: debug
  testing: true
  listen: 0.0.0.0
  port: 9090
  allowAllUsers: true

  jwt:
    secret: Your-64-character-secret-key-here
    issuer: Vouch
    compress: false

  cookie:
    name: my-vouch-ct
    secure: false
    domain: localhost

  headers:
    jwt: X-Vouch-Token
    querystring: access_token
    redirect: X-Vouch-Requested-URI
    accesstoken: X-Vouch-IdP-AccessToken
    idtoken: X-Vouch-IdP-IdToken

  post_logout_redirect_uris:
    - https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}&returnTo=http://localhost/

oauth:
  # Generic OpenID Connect
  # including okta
  provider: oidc
  client_id: {client_id_from_auth0}
  client_secret: {client_secret_from_auth0}
  auth_url: https://your-project-name-here.eu.auth0.com/authorize
  token_url: https://your-project-name-here.eu.auth0.com/oauth/token
  user_info_url: https://your-project-name-here.eu.auth0.com/userinfo
  scopes:
    - openid
    - email
    - profile
  callback_url: http://localhost:9090/auth
Example config.yml for vouch proxy

To store this file, create a sub-folder named vouch in your project’s folder and, inside it, another one named config. The relative path (from the project root) to the file is ./vouch/config/config.yml. You can check the GitHub repo for reference.
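
For example, from the project root on a Unix-like shell:

mkdir -p vouch/config    # creates both nested folders in one go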

Next, it is time to test if vouch-proxy can correctly communicate with Auth0. I promised a dockerized setup, so let’s create a docker-compose file. (We will expand this file later.)

version: '3'
services:
  vouch:
    image: voucher/vouch-proxy
    volumes:
      - ./vouch/config/config.yml:/config/config.yml
    ports:
      - 9090:9090
Initial Docker-Compose config file.

Save the file in the project’s root directory as docker-compose.yml. Then, navigate to the root directory and bring up the “stack” with

docker-compose up

Once it is up and running, you can open your browser and navigate to localhost:9090. You should be greeted with a

404 page not found

This is expected, so don’t worry; it is merely a test that the server is alive (a “page not found” beats a “site can’t be reached”).
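
If you prefer the terminal over the browser, the same check might look like this (a sketch, assuming curl is installed):

curl -i http://localhost:9090/
# HTTP/1.1 404 Not Found
# ...
# 404 page not found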

Next, we will test the login flow between vouch-proxy and Auth0. For this navigate to

http://localhost:9090/login?url=http://localhost:9090/validate

The url= parameter specifies the location that we want the user to return to after the login has completed. In this case, we navigate to /validate, which is the endpoint we will use throughout the app to validate the client’s access token.

Once you put that into your browser, you will be greeted by a simple HTML page telling you that you are being redirected to some address. This is vouch-proxy’s debug mode, which lets you check if your flow works correctly. Click the link that forwards to Auth0.

Vouch Proxy testing page.

This should present you with Auth0’s login form. Here we want to create a new user with username and password and authenticate ourselves.

Signup Page at Auth0.

Once you have an account and have accepted the permissions, you will be redirected to vouch-proxy and it will confirm that you are logged in with your chosen email.

Vouch proxy indicated that the user is authorized.

Now the only thing left to test is logging out the user. Here there are multiple options: (1) you can log the user out of your app (vouch-proxy), (2) log the user out of your app and Auth0, or (3) log the user out of your app, Auth0, and their social login provider (if they use a social login). We will not cover the third one here.

To use the first option navigate to

localhost:9090/logout

which will tell you that you have been logged out. At this point, you are still logged in at Auth0, so after logging back into the app ( http://localhost:9090/login?url=http://localhost:9090/validate ), you will not be asked to log in at Auth0.

To log yourself out of Auth0 alongside your app, you have to tell vouch-proxy to redirect the user to Auth0’s logout URL. You can read more about it here. To log out in both places, use the URL below.

http://localhost:9090/logout?url=https://dev-simple.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}%26returnTo=http://localhost:9090/validate

Be sure to replace {client_id_from_auth0} with your client ID. Also, notice the percent encoding of the ampersand (%26), which, if left out, will break the logout procedure. If you log out with this link, and try to log in again, you will be asked to provide your username and password at Auth0 again.

Behind the scenes, there are two places where this callback needs to be authorized (otherwise it won’t happen). First, vouch-proxy needs to find the url= parameter inside the list of post_logout_redirect_uris (check the config.yml). Second, the returnTo= parameter of the redirect needs to be added to Allowed Logout URLs in Auth0’s config. We have done this in the previous section. If something breaks for you, make sure to check these locations (and the returned X-Vouch-Error header).

Vouch-Proxy is working and communicating with Auth0! Next, we will set up nginx as a reverse proxy sitting in front of vouch-proxy and our app.

Setup Nginx

The next step in the process is to set up an Nginx server that can act as a reverse proxy for our app and vouch-proxy. For this, create a new config file at ./nginx/conf.d/server.conf with the configuration below

server {
    listen 80;
    server_name localhost;

    location ^~ /sso/ {
        location /sso/validate {
            proxy_pass http://vouch:9090/validate;
            proxy_set_header Host $http_host;
            proxy_pass_request_body off;
        }

        location = /sso/logout {
            proxy_pass http://vouch:9090/logout?url=https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}%26returnTo=http://localhost/;
            proxy_set_header Host $http_host;
        }

        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/;
    }

    # uncomment this to forward static content of vouch-proxy
    # used when running vouch-proxy with `testing: true`
    location /static/ {
        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/static/;
    }

    location / {
        root /usr/share/nginx/html;
        index index.html;
    }
}
Initial nginx configuration.

Also, update the docker-compose.yml like so

version: '3'
services:
  vouch:
    image: voucher/vouch-proxy
    volumes:
      - ./vouch/config/config.yml:/config/config.yml
  nginx:
    image: nginx
    depends_on:
      - vouch
    volumes:
      - ./nginx/conf.d/:/etc/nginx/conf.d/
    ports:
      - 80:80
updated docker-compose.yml

and finally, update the vouch-proxy configuration so that it calls back to the new location. For this, you only have to change the callback_url variable in the last line of the config file.

# vouch config
# bare minimum to get vouch running with OpenID Connect (such as okta)

vouch:
  logLevel: debug
  testing: true
  listen: 0.0.0.0
  port: 9090
  allowAllUsers: true

  jwt:
    secret: Your-64-character-secret-key-here
    issuer: Vouch
    compress: false

  cookie:
    name: my-vouch-ct
    secure: false
    domain: localhost

  headers:
    jwt: X-Vouch-Token
    querystring: access_token
    redirect: X-Vouch-Requested-URI
    accesstoken: X-Vouch-IdP-AccessToken
    idtoken: X-Vouch-IdP-IdToken

  post_logout_redirect_uris:
    - https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}&returnTo=http://localhost/

oauth:
  # Generic OpenID Connect
  # including okta
  provider: oidc
  client_id: {client_id_from_auth0}
  client_secret: {client_secret_from_auth0}
  auth_url: https://your-project-name-here.eu.auth0.com/authorize
  token_url: https://your-project-name-here.eu.auth0.com/oauth/token
  user_info_url: https://your-project-name-here.eu.auth0.com/userinfo
  scopes:
    - openid
    - email
    - profile
  callback_url: http://localhost/sso/auth
updated vouch-proxy config.yml

Now update the docker stack so that it uses the new configuration.
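
If you started the stack in the foreground, stop it with Ctrl+C; one way to bring it back up with freshly created containers (so that the new configs are picked up) is:

docker-compose up --force-recreate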

What has just happened? We have added another node (nginx) to our docker stack and added a configuration for a server on the default HTTP port. The first location block (^~ /sso/) acts as a reverse proxy for vouch-proxy. It has specializations for the /validate endpoint (no request body needed) and for the /logout endpoint (for convenience). All authorization calls will hence go to localhost/sso/.

The second location block (/static/) handles requests for vouch-proxy’s static files (the logo and the .css you see during the 302 redirects). This block is only needed while testing: true is set in the vouch-proxy config. Otherwise, the debugging website is not shown and we can remove this block.

The third location block is where our app will live. For now, it is the default nginx website.

Let’s test this setup. First navigate to

localhost/

and make sure that nginx is up and running. Then navigate to

localhost/sso/

and make sure that nginx is correctly forwarding to vouch-proxy (you should see the familiar 404 page not found). Now, test the login by navigating to

localhost/sso/login?url=http://localhost/sso/validate

This should result in the same flow that you are familiar with from the previous section, except that the URL now contains localhost/sso/ instead of localhost:9090/.

To log out simply visit

localhost/sso/logout

Notice how you are also logged out of Auth0. Nginx adds the necessary parameters before passing the request on to vouch-proxy.

Secure your App

So far, we have set up nginx, vouch-proxy, and Auth0 in a neat docker stack and we have verified that everything is working. What we haven’t done yet is integrate the actual app.

First, let’s create a super basic app that nginx can serve. Create a new file at ./web/index.html and fill it with a simple button to view protected content

<!DOCTYPE html>
<html>
<head>
    <title>Simple App</title>
</head>
<body>
    <button onclick="location.href='/sso/login?url=http\:\/\/localhost/protected';" id="myButton" class="float-left submit-button">View Protected Content</button>
</body>
</html>
homepage of the app

and also a page that requires login to view at ./web/protected/index.html

<!DOCTYPE html>
<html>
<head>
    <title>Simple App</title>
</head>
<body>
    <p>This content is protected, and can't be seen without being logged in.</p>
    <button onclick="location.href='/sso/logout';" id="myButton" class="float-left submit-button">Logout</button>
</body>
</html>
protected page in the app

Then, add the files to the nginx container by updating the docker-compose.yml

version: '3'
services:
  vouch:
    image: voucher/vouch-proxy
    volumes:
      - ./vouch/config/config.yml:/config/config.yml
  nginx:
    image: nginx
    depends_on:
      - vouch
    volumes:
      - ./nginx/conf.d/:/etc/nginx/conf.d/
      - ./web/:/usr/share/nginx/html/
    ports:
      - 80:80
updated docker-compose.yml

Restart/update the stack, and you can see your website at localhost/. When you click “View Protected Content”, you will see the page that should be protected. When you click “Logout” on the protected page, you will trigger the logout flow familiar from the previous sections.

To actually protect the content, we need to add a new location to nginx and protect it. This is done easily by updating the config file.

server {
    listen 80;
    server_name localhost;

    location ^~ /sso/ {
        location /sso/validate {
            proxy_pass http://vouch:9090/validate;
            proxy_set_header Host $http_host;
            proxy_pass_request_body off;
        }

        location = /sso/logout {
            proxy_pass http://vouch:9090/logout?url=https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}%26returnTo=http://localhost/;
            proxy_set_header Host $http_host;
        }

        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/;
    }

    # uncomment this to forward static content of vouch-proxy
    # used when running vouch-proxy with `testing: true`
    location /static/ {
        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/static/;
    }

    location / {
        root /usr/share/nginx/html;
        index index.html;
    }

    location /protected {
        auth_request /sso/validate;
        root /usr/share/nginx/html;
        index index.html;
        expires 0;
        add_header Cache-Control "no-cache, no-store, must-revalidate, max-age=0";
        add_header Pragma "no-cache";
    }
}
updated nginx server config

Now, when you click “View Protected Content” or manually navigate to localhost/protected, you will see the protected page (if you are logged in) or get a 401 Unauthorized error (if not).
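
You can also verify this from the terminal. A request without the vouch session cookie should be rejected (a sketch, assuming curl):

curl -i http://localhost/protected/
# HTTP/1.1 401 Unauthorized    (no valid vouch cookie was sent along)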

Next, we can have nginx catch the error and, instead of raising it, redirect the user to the login procedure with a simple addition to the server.conf

server {
    listen 80;
    server_name localhost;

    location ^~ /sso/ {
        location /sso/validate {
            proxy_pass http://vouch:9090/validate;
            proxy_set_header Host $http_host;
            proxy_pass_request_body off;
        }

        location = /sso/logout {
            proxy_pass http://vouch:9090/logout?url=https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}%26returnTo=http://localhost/;
            proxy_set_header Host $http_host;
        }

        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/;
    }

    # uncomment this to forward static content of vouch-proxy
    # used when running vouch-proxy with `testing: true`
    location /static/ {
        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/static/;
    }

    location / {
        root /usr/share/nginx/html;
        index index.html;
    }

    location /protected {
        auth_request /sso/validate;
        root /usr/share/nginx/html;
        index index.html;
        expires 0;
        add_header Cache-Control "no-cache, no-store, must-revalidate, max-age=0";
        add_header Pragma "no-cache";
        error_page 401 = @prompt_login;
    }

    location @prompt_login {
        return 302 http://localhost/sso/login?url=$scheme://$http_host$request_uri;
    }
}
updated nginx server config

Now, the user will be asked to log in if they try to access the protected location, and only if the login succeeds will they be able to view the protected page.

Bonus: Add HTTPS via self-signed certificates

In 2020, servers should enforce HTTPS, and while it is not necessary for localhost development, it is very nice to have, especially later, when you have a development and a production version of the code.

Adding SSL is very easy (shameless self-plug). First, generate a self-signed certificate for localhost (make sure to enter localhost as the common name):

docker run --rm -it -v$PWD:/certs firefoxmetzger/create_localhost_ssl

(Source: https://github.com/FirefoxMetzger/create_ssl)

This will place a certificate and a private key into your current working directory, which you can move to ./cert/. Also, if you don’t want to be warned about an untrusted certificate, you can consider adding it to your browser’s trusted certificates.
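
If you would rather not use the Docker image, a plain openssl command should produce an equivalent pair (a sketch; the file names are chosen to match the nginx config used below):

mkdir -p cert
openssl req -x509 -nodes -newkey rsa:2048 -days 365 \
    -subj "/CN=localhost" \
    -keyout cert/private.key -out cert/certificate.crt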

Next, we have to update the server.conf for nginx

server {
    listen 80;
    server_name _;
    return 301 https://$host$request_uri;
}

server {
    listen 443 ssl;
    server_name localhost;

    ssl_certificate /certs/certificate.crt;
    ssl_certificate_key /certs/private.key;

    location ^~ /sso/ {
        location /sso/validate {
            proxy_pass http://vouch:9090/validate;
            proxy_set_header Host $http_host;
            proxy_pass_request_body off;
        }

        location = /sso/logout {
            proxy_pass http://vouch:9090/logout?url=https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}%26returnTo=https://localhost/;
            proxy_set_header Host $http_host;
        }

        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/;
    }

    # uncomment this to forward static content of vouch-proxy
    # used when running vouch-proxy with `testing: true`
    location /static/ {
        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/static/;
    }

    location / {
        root /usr/share/nginx/html;
        index index.html;
    }

    location /protected {
        auth_request /sso/validate;
        root /usr/share/nginx/html;
        index index.html;
        expires 0;
        add_header Cache-Control "no-cache, no-store, must-revalidate, max-age=0";
        add_header Pragma "no-cache";
        error_page 401 = @prompt_login;
    }

    location @prompt_login {
        return 302 https://localhost/sso/login?url=$scheme://$http_host$request_uri;
    }
}
updated server.conf

The new first server block forwards all HTTP requests to HTTPS. Then, we add the SSL certificate we have just generated and change nginx to listen on the standard HTTPS port. Finally, we change the protocol from HTTP to HTTPS for both redirects.

Next, update the callback_url as well as the post_logout_redirect_uris in the vouch-proxy config.

# vouch config
# bare minimum to get vouch running with OpenID Connect (such as okta)

vouch:
  logLevel: debug
  testing: true
  listen: 0.0.0.0
  port: 9090
  allowAllUsers: true

  jwt:
    secret: Your-64-character-secret-key-here
    issuer: Vouch
    compress: false

  cookie:
    name: my-vouch-ct
    secure: false
    domain: localhost

  headers:
    jwt: X-Vouch-Token
    querystring: access_token
    redirect: X-Vouch-Requested-URI
    accesstoken: X-Vouch-IdP-AccessToken
    idtoken: X-Vouch-IdP-IdToken

  post_logout_redirect_uris:
    - https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}&returnTo=https://localhost/

oauth:
  # Generic OpenID Connect
  # including okta
  provider: oidc
  client_id: {client_id_from_auth0}
  client_secret: {client_secret_from_auth0}
  auth_url: https://your-project-name-here.eu.auth0.com/authorize
  token_url: https://your-project-name-here.eu.auth0.com/oauth/token
  user_info_url: https://your-project-name-here.eu.auth0.com/userinfo
  scopes:
    - openid
    - email
    - profile
  callback_url: https://localhost/sso/auth
updated config.yml

Last but not least, make the certificates available to nginx by mounting the folder into the nginx container, and open port 443 to allow SSL connections.

version: '3'
services:
  vouch:
    image: voucher/vouch-proxy
    volumes:
      - ./vouch/config/config.yml:/config/config.yml
  nginx:
    image: nginx
    depends_on:
      - vouch
    volumes:
      - ./nginx/conf.d/:/etc/nginx/conf.d/
      - ./web/:/usr/share/nginx/html/
      - ./cert/:/certs
    ports:
      - 443:443
      - 80:80
updated docker-compose.yml

Now, when you update the docker stack, you should be able to navigate to https://localhost (potentially receiving a warning that the certificate could not be verified) and browse your app over an encrypted connection.
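
From the terminal, the redirect and the TLS endpoint can be checked like this (a sketch, assuming curl; -k skips verification of the self-signed certificate):

curl -i http://localhost/      # should answer with a 301 to https://localhost/
curl -ik https://localhost/    # should serve the app's index.html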

Remove Debugging

The only thing left is to remove some of the config that we have introduced for debugging purposes.

First, in the settings for the tenant at Auth0, remove the unneeded Allowed Callback and Logout URLs

Final Auth0 settings

Then, disable debug logs and testing mode in vouch-proxy.

# vouch config
# bare minimum to get vouch running with OpenID Connect (such as okta)

vouch:
  listen: 0.0.0.0
  port: 9090
  allowAllUsers: true

  jwt:
    secret: Your-64-character-secret-key-here
    issuer: Vouch
    compress: false

  cookie:
    name: my-vouch-ct
    secure: false
    domain: localhost

  headers:
    jwt: X-Vouch-Token
    querystring: access_token
    redirect: X-Vouch-Requested-URI
    accesstoken: X-Vouch-IdP-AccessToken
    idtoken: X-Vouch-IdP-IdToken

  post_logout_redirect_uris:
    - https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}&returnTo=https://localhost/

oauth:
  # Generic OpenID Connect
  # including okta
  provider: oidc
  client_id: {client_id_from_auth0}
  client_secret: {client_secret_from_auth0}
  auth_url: https://your-project-name-here.eu.auth0.com/authorize
  token_url: https://your-project-name-here.eu.auth0.com/oauth/token
  user_info_url: https://your-project-name-here.eu.auth0.com/userinfo
  scopes:
    - openid
    - email
    - profile
  callback_url: https://localhost/sso/auth
final config.yml for vouch-proxy

And update the nginx config by removing the /static/ route.

server {
    listen 80;
    server_name _;
    return 301 https://$host$request_uri;
}

server {
    listen 443 ssl;
    server_name localhost;

    ssl_certificate /certs/certificate.crt;
    ssl_certificate_key /certs/private.key;

    location ^~ /sso/ {
        location /sso/validate {
            proxy_pass http://vouch:9090/validate;
            proxy_set_header Host $http_host;
            proxy_pass_request_body off;
        }

        location = /sso/logout {
            proxy_pass http://vouch:9090/logout?url=https://your-project-name-here.eu.auth0.com/v2/logout?client_id={client_id_from_auth0}%26returnTo=https://localhost/;
            proxy_set_header Host $http_host;
        }

        proxy_set_header Host $http_host;
        proxy_pass http://vouch:9090/;
    }

    # uncomment this to forward static content of vouch-proxy
    # used when running vouch-proxy with `testing: true`
    #location /static/ {
    #    proxy_set_header Host $http_host;
    #    proxy_pass http://vouch:9090/static/;
    #}

    location / {
        root /usr/share/nginx/html;
        index index.html;
    }

    location /protected {
        auth_request /sso/validate;
        root /usr/share/nginx/html;
        index index.html;
        expires 0;
        add_header Cache-Control "no-cache, no-store, must-revalidate, max-age=0";
        add_header Pragma "no-cache";
        error_page 401 = @prompt_login;
    }

    location @prompt_login {
        return 302 https://localhost/sso/login?url=$scheme://$http_host$request_uri;
    }
}
final server.conf for nginx

Done! Now you have a simple app that is secured with Auth0 and vouch-proxy. You can add an API to this in the same way we have added the /protected route. Simply add

auth_request /sso/validate;

to the route that proxies the API.
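
For example, if your API ran in another container, the corresponding location block might look like the sketch below (the upstream name api and port 8000 are hypothetical and depend on your docker-compose setup):

location /api/ {
    auth_request /sso/validate;
    proxy_set_header Host $http_host;
    proxy_pass http://api:8000/;
}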

If you have any questions or comments, feel free to comment below.

Thanks for reading and Happy Coding!

Related Useful Links

Layered Layers: Residual Blocks in the Sequential Keras API

I’ve been looking at the AlphaGo:Zero network architecture [1] and was searching for existing implementations. I’ve found quite a few (here, here, and here) with varying degrees of completeness. The cleanest is probably this one, but it depends on Jupyter.

What surprised me was that I couldn’t find one that used Keras’ sequential API. While residual blocks aren’t exactly sequential, from a high level view the architecture itself is; it simply stacks (a lot of) residual blocks. So it should be possible to create something like this, right?

The answer is, of course: yes, there isn’t much that you can’t do in Python. We are actually using this strategy already. Sequential itself inherits from Layer and, in fact, Container (a class sitting between Sequential and Layer in the inheritance hierarchy) states so itself: “A Container is a directed acyclic graph of layers. It is the topological form of a ‘model’. A Model is simply a Container with added training routines.” (source)

It works by defining the residual block as a new Keras layer. Depending on how tightly integrated you want it this can be quite short:


from keras.engine.topology import Layer
from keras.layers import Activation, Conv2D, Add


class Residual(Layer):
    def __init__(self, channels_in, kernel, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.channels_in = channels_in
        self.kernel = kernel

    def call(self, x):
        # the residual block using Keras functional API
        first_layer = Activation("linear", trainable=False)(x)
        x = Conv2D(self.channels_in,
                   self.kernel,
                   padding="same")(first_layer)
        x = Activation("relu")(x)
        x = Conv2D(self.channels_in,
                   self.kernel,
                   padding="same")(x)
        residual = Add()([x, first_layer])
        x = Activation("relu")(residual)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape

Inside the block we fall back to the functional way of stacking layers. If you want better integration, e.g. model.summary() showing the number of trainable weights, there is additional plumbing required. The above just shows the gist . . . (gosh! That pun was bad).

Once that is written, we can use model.add( Residual(32, (3,3) )) as we would any other layer. Nice!

To close with an example, I modified the Keras CNN example on CIFAR10 and replaced the hidden convolutional layers with residual ones. I haven’t optimized performance, but you can see how it works. If you are familiar with the example, you might appreciate how similar it looks.


'''Train a simple residual network on the CIFAR10 small images dataset.

It gets to 75% validation accuracy in 25 epochs, and 79% after 50 epochs.
(it's still underfitting at that point, though).
'''

from __future__ import print_function
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D, Add
import os
from keras.engine.topology import Layer


# Define the residual block as a new layer
class Residual(Layer):
    def __init__(self, channels_in, kernel, **kwargs):
        super(Residual, self).__init__(**kwargs)
        self.channels_in = channels_in
        self.kernel = kernel

    def call(self, x):
        # the residual block using Keras functional API
        first_layer = Activation("linear", trainable=False)(x)
        x = Conv2D(self.channels_in,
                   self.kernel,
                   padding="same")(first_layer)
        x = Activation("relu")(x)
        x = Conv2D(self.channels_in,
                   self.kernel,
                   padding="same")(x)
        residual = Add()([x, first_layer])
        x = Activation("relu")(residual)
        return x

    def compute_output_shape(self, input_shape):
        return input_shape


batch_size = 32
num_classes = 10
epochs = 100
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Residual(32, (3, 3)))
model.add(Residual(32, (3, 3)))
model.add(Residual(32, (3, 3)))
model.add(Residual(32, (3, 3)))
model.add(Residual(32, (3, 3)))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# initiate RMSprop optimizer
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        rotation_range=0,  # randomly rotate images in the range (degrees, 0 to 180)
        width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
        height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
        horizontal_flip=True,  # randomly flip images
        vertical_flip=False)  # randomly flip images

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        workers=4)

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])

cifar10_res.py

References

[1] Silver, David, et al. “Mastering the game of go without human knowledge.” Nature 550.7676 (2017): 354.

Parsing TFRecords with the Tensorflow Dataset API

Update: Datasets are now part of the example in the Tensorflow library.

The Datasets API has become the new standard for feeding data into Tensorflow. Moreover, there seem to be plans to deprecate queues and other inputs, unifying the way data is fed into models. The idea now is to (1) create a Dataset object (in this case a TFRecordDataset) and then (2) create an Iterator that extracts elements and feeds them into the model.

I’ve modified tensorflow’s example on “how to read data” to reflect that change. I’ve submitted a PR to the tensorflow repo; until it gets merged, take a look at the new code below. It is a lot easier to read; see for yourself:


# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Train and Eval the MNIST network.

This version is like fully_connected_feed.py but uses data converted
to a TFRecords file containing tf.train.Example protocol buffers.
See:
https://www.tensorflow.org/programmers_guide/reading_data#reading_from_files
for context.

YOU MUST run convert_to_records before running this (but you only need to
run it once).
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os.path
import sys
import time

import tensorflow as tf

from tensorflow.examples.tutorials.mnist import mnist

# Basic model parameters as external flags.
FLAGS = None

# Constants used for dealing with the files, matches convert_to_records.
TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'


def decode(serialized_example):
  features = tf.parse_example(
      [serialized_example],
      # Defaults are not specified since both keys are required.
      features={
          'image_raw': tf.FixedLenFeature([], tf.string),
          'label': tf.FixedLenFeature([], tf.int64),
      })

  # Convert from a scalar string tensor (whose single string has
  # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
  # [mnist.IMAGE_PIXELS].
  image = tf.decode_raw(features['image_raw'], tf.uint8)
  image.set_shape((1, mnist.IMAGE_PIXELS))
  image = tf.reshape(image, [-1])

  # Convert label from a scalar uint8 tensor to an int32 scalar.
  label = tf.cast(features['label'], tf.int32)

  return image, label


def augment(image, label):
  # OPTIONAL: Could reshape into a 28x28 image and apply distortions
  # here. Since we are not applying any distortions in this
  # example, and the next step expects the image to be flattened
  # into a vector, we don't bother.
  return image, label


def normalize(image, label):
  # Convert from [0, 255] -> [-0.5, 0.5] floats.
  image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
  return image, label


def inputs(train, batch_size, num_epochs):
  """Reads input data num_epochs times.

  Args:
    train: Selects between the training (True) and validation (False) data.
    batch_size: Number of examples per returned batch.
    num_epochs: Number of times to read the input data, or 0/None to
      train forever.

  Returns:
    A tuple (images, labels), where:
    * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
      in the range [-0.5, 0.5].
    * labels is an int32 tensor with shape [batch_size] with the true label,
      a number in the range [0, mnist.NUM_CLASSES).
  """
  if not num_epochs:
    num_epochs = None
  filename = os.path.join(FLAGS.train_dir,
                          TRAIN_FILE if train else VALIDATION_FILE)

  with tf.name_scope('input'):
    # create the dataset
    dataset = tf.data.TFRecordDataset(filename)

    # iterate this dataset num_epoch times
    dataset = dataset.repeat(num_epochs)

    dataset = dataset.map(decode)
    dataset = dataset.map(augment)
    dataset = dataset.map(normalize)

    dataset = dataset.shuffle(1000 + 3 * batch_size)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
  return iterator.get_next()


def run_training():
  """Train MNIST for a number of steps."""

  # Tell TensorFlow that the model will be built into the default Graph.
  with tf.Graph().as_default():
    # Input images and labels.
    image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size,
                                      num_epochs=FLAGS.num_epochs)

    # Build a Graph that computes predictions from the inference model.
    logits = mnist.inference(image_batch,
                             FLAGS.hidden1,
                             FLAGS.hidden2)

    # mnist.loss expects a (batch_size,) instead of (batch_size, 1)
    label_batch = tf.reshape(label_batch, [-1])

    # Add to the Graph the loss calculation.
    loss = mnist.loss(logits, label_batch)

    # Add to the Graph operations that train the model.
    train_op = mnist.training(loss, FLAGS.learning_rate)

    # The op for initializing the variables.
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())

    # Create a session for running operations in the Graph.
    with tf.Session().as_default() as sess:
      # Initialize the variables (the trained variables and the
      # epoch counter).
      sess.run(init_op)
      try:
        step = 0
        while True:  # train until OutOfRangeError
          start_time = time.time()

          # Run one step of the model. The return values are
          # the activations from the `train_op` (which is
          # discarded) and the `loss` op. To inspect the values
          # of your ops or variables, you may include them in
          # the list passed to sess.run() and the value tensors
          # will be returned in the tuple from the call.
          _, loss_value = sess.run([train_op, loss])

          duration = time.time() - start_time

          # Print an overview fairly often.
          if step % 100 == 0:
            print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                       duration))
          step += 1
      except tf.errors.OutOfRangeError:
        print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
                                                          step))


def main(_):
  run_training()


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.01,
      help='Initial learning rate.'
  )
  parser.add_argument(
      '--num_epochs',
      type=int,
      default=2,
      help='Number of epochs to run trainer.'
  )
  parser.add_argument(
      '--hidden1',
      type=int,
      default=128,
      help='Number of units in hidden layer 1.'
  )
  parser.add_argument(
      '--hidden2',
      type=int,
      default=32,
      help='Number of units in hidden layer 2.'
  )
  parser.add_argument(
      '--batch_size',
      type=int,
      default=100,
      help='Batch size.'
  )
  parser.add_argument(
      '--train_dir',
      type=str,
      default='/tmp/data',
      help='Directory with the training data.'
  )
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
