My DNN predict script started hanging recently. Can anyone tell me why the predict method hangs when I execute it in a multiprocessing pool?
```python
import multiprocessing

import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

dataset = np.random.rand(50, 9)
# split into input (X) and output (y) variables
X = dataset[:, 0:8]
y = dataset[:, 8]

model = Sequential()
model.add(Dense(12, input_shape=(8,), activation='relu'))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy')
model.fit(X, y, epochs=3)

# make class predictions with the model
def predict(i):
    print(f"predict rows {i}:{i+2}")
    return model.predict(X[i:i+2, :])

with multiprocessing.Pool(processes=2) as pool:
    data = pool.map(predict, [0, 2])
print(data)
```
I’m trying to evaluate a DNN model and split the job across processors. I’m running on Casper with 3 CPUs reserved.
@David Ahijevych Which version of TensorFlow are you using? Would it make sense to use GPUs for distributed prediction (and/or training)?
Hi Katie,
TensorFlow 2.11.
I tried earlier versions, but conda couldn't find a match with the C libraries. Yes, GPUs probably make sense, but I've had trouble: when I run on GPUs my scripts execute faster, but I sometimes get an out-of-memory error that is resolved by using CPUs instead.
I see. With tensorflow version 1 I had problems with multiprocessing pool, but I thought that should be resolved with version 2. There is a `use_multiprocessing` argument that you can use with keras `model.predict`, maybe that will help? https://keras.io/api/models/model_training_apis/#predict-method
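For concreteness, that suggestion is just a pair of keyword arguments to `predict()` (they exist in TF 2.x up through 2.11; `workers=2` here is an illustrative value, not something from this thread):

```python
# use_multiprocessing and workers are keyword arguments of Model.predict
preds = model.predict(X, workers=2, use_multiprocessing=True)
```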
Good idea. I tried `use_multiprocessing`, but to no avail. According to the method description, `use_multiprocessing` and `workers` may help only if your input is a generator or `keras.utils.Sequence`. But mine is not a generator, and I don't even know what a keras Sequence is. So I might have to look elsewhere.
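In case it helps a future reader: a `keras.utils.Sequence` is just a class that serves your data in batches via `__len__` and `__getitem__`. A minimal sketch of wrapping a NumPy array in one (the class name and batch size are illustrative, not from this thread):

```python
import math

from tensorflow.keras.utils import Sequence

class ArraySequence(Sequence):
    """Serve a NumPy array to Keras in fixed-size batches."""

    def __init__(self, x, batch_size=32):
        self.x = x
        self.batch_size = batch_size

    def __len__(self):
        # number of batches needed to cover the array
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        # slice out batch number `idx`
        return self.x[idx * self.batch_size:(idx + 1) * self.batch_size]

# model and X as defined in the script at the top of the thread
preds = model.predict(ArraySequence(X, batch_size=10),
                      workers=2, use_multiprocessing=True)
```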
Going forward, I'll avoid multiprocessing with `model.predict()`. For my larger files (50M+ rows), I'll simply increase the `batch_size` from the default 32 to 5000 to get the speedup I want. I had no idea that would help so much, and I don't need multiprocessing.
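That fix is a single keyword argument; a sketch of the call, using the 5000 from the message above (`X_big` standing in for the larger input array):

```python
# batch_size is a standard keyword of Model.predict (default 32);
# larger batches mean far fewer predict-loop iterations on big inputs
preds = model.predict(X_big, batch_size=5000)
```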