Faster RNG for Reinforcement Learning in Python

When you start doing reinforcement learning, you will sooner or later need to generate random numbers. Initializing policy networks or Q-tables, choosing between exploration and exploitation, or selecting among equally good actions are a few examples. NumPy is very efficient at generating those random numbers, and most of the time (like 95%) this is all you need. However, there is one particular edge case where NumPy is not the best solution, and it is exactly the case we encounter a lot in RL: generating single random numbers (e.g., to select an action epsilon-greedily).

Generating single random numbers in NumPy is a bad idea, because every NumPy call carries a fixed overhead for dispatching into NumPy's compiled code and returning to Python, and that overhead dominates the runtime when each call yields only one number. In this case it is (much) more efficient to use the Python standard library instead. However, if you can generate the random numbers in batches, NumPy is again significantly faster than the standard library. Take a look at this simple profile:

import timeit
number = 10000
numpy_time = timeit.timeit("[np.random.rand() for _ in range(int(1e3))]", "import numpy as np", number=number)
random_time = timeit.timeit("[random.random() for _ in range(int(1e3))]", "import random", number=number)
numpy_batch_time = timeit.timeit("np.random.rand(int(1e3))", "import numpy as np", number=number)
print("Timings")
print("=======")
print(f"Numpy Single: {numpy_time:.3f}")
print(f"Random: {random_time:.3f}")
print(f"Numpy Batch: {numpy_batch_time:.3f}")
print("=======")
# =======
# Numpy Single: 3.245
# Random: 1.003
# Numpy Batch: 0.085
# =======
Comparison between ways to generate random numbers
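
To make the single-number case concrete, here is a minimal sketch of epsilon-greedy action selection that uses only the standard library. The function name `epsilon_greedy` and its arguments are my own illustration, not part of any RL library, and the sketch assumes the Q-values for one state are given as a plain Python list.

```python
import random


def epsilon_greedy(q_values, epsilon):
    """Pick an action index from a list of Q-values (hypothetical helper)."""
    # Explore: a single scalar draw from the standard library, no NumPy call.
    if random.random() < epsilon:
        return random.randrange(len(q_values))
    # Exploit: break ties between equally good actions at random.
    best = max(q_values)
    best_actions = [i for i, q in enumerate(q_values) if q == best]
    return random.choice(best_actions)
```

Each call draws at most two scalars, which is the situation where the per-call overhead measured above matters most.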

So if you profile your code and notice that RNG adds up to a significant portion of your runtime, consider pre-generating the random numbers in NumPy and saving them to a list. This solution sacrifices a bit of readability, but allows for much faster code. Here is an example that mimics the syntax of the Python standard library:

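import numpy as np

# Pre-generate the pool once. As a Python list, 2e8 floats take several GB,
# so shrink the pool if memory is tight.
random_numbers = np.random.rand(int(2e8)).tolist()

def random():
    """Return the next pre-generated float in [0, 1)."""
    try:
        return random_numbers.pop()
    except IndexError:
        raise IndexError("Out of random numbers; generate more next time.")

def randint(a, b):
    """Return an integer in [a, b], both ends included, like random.randint."""
    # Scale by the number of possible outcomes so every integer is equally likely.
    return int((b - a + 1) * random()) + a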
Source Code for Random replacement.
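
If running out of numbers is a concern, one possible variation is to refill the pool in batches instead of raising an error. The sketch below is an assumption on my part rather than the code from the gist: the class name `BufferedRandom`, its `batch_size` and `seed` parameters, and the choice of `np.random.default_rng` are all illustrative, and I have not benchmarked it against the version above.

```python
import numpy as np


class BufferedRandom:
    """Hypothetical helper: serves scalars from NumPy batches, refilling on demand."""

    def __init__(self, batch_size=100_000, seed=None):
        self._rng = np.random.default_rng(seed)
        self._batch_size = batch_size
        self._buffer = []

    def random(self):
        """Return a float in [0, 1), drawing a new batch when the buffer is empty."""
        if not self._buffer:
            # One vectorized NumPy call pays the call overhead once per batch.
            self._buffer = self._rng.random(self._batch_size).tolist()
        return self._buffer.pop()

    def randint(self, a, b):
        """Return an integer in [a, b], both ends included."""
        return a + int((b - a + 1) * self.random())
```

Usage would look like `rng = BufferedRandom(seed=0)` followed by `rng.random()` or `rng.randint(1, 6)`. A smaller batch keeps memory use low at the cost of more frequent refills, so the right `batch_size` depends on your workload.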

That’s it. Thanks for reading and Happy Coding!
