El código usado para crear los TFRecords está basado en TensorFlow 1.x, una versión que no debería usarse en proyectos nuevos. Aquí les preparé una implementación con TensorFlow 2.x. Si tienen comentarios y/o correcciones, son bienvenidos:
def class_text_to_int(row_label):
    """Translate a class name into its integer class id.

    Args:
        row_label: Class name to look up.

    Returns:
        1 when ``row_label`` equals ``MOTORCYCLE_LABEL``, 2 when it equals
        ``CAR_LABEL``, and ``None`` for any other value.
    """
    # Pairs are checked in order, so the motorcycle label takes precedence.
    for known_label, class_id in ((MOTORCYCLE_LABEL, 1), (CAR_LABEL, 2)):
        if row_label == known_label:
            return class_id
    return None
def split(df, group):
    """Bundle the rows of ``df`` into one record per group key.

    Args:
        df: DataFrame to partition.
        group: Column label handed to ``DataFrame.groupby`` (the filename
            column, per the field name used in the result).

    Returns:
        A list of ``data`` namedtuples, one per key in the groupby's
        ``groups`` mapping. ``filename`` holds the group key and ``object``
        holds the sub-DataFrame with the rows of that group.
    """
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    # Iterating the groups mapping yields each key once; no need to zip the
    # keys with themselves.
    return [data(filename, gb.get_group(filename)) for filename in gb.groups]
def create_tf_example(group, path):
    """Build a ``tf.train.Example`` for one image and its annotated boxes.

    The image is loaded from ``DATASET_PATH``, re-encoded as JPEG and stored
    together with its size, filename and per-object box/class lists.

    Args:
        group: Record with a ``filename`` attribute and an ``object``
            DataFrame providing the columns ``xmin``, ``xmax``, ``ymin``,
            ``ymax`` and ``classname`` (one row per annotated object).
        path: Unused; kept so existing callers keep working.
            NOTE(review): the image directory comes from ``DATASET_PATH``
            instead of this argument -- confirm whether that is intended.

    Returns:
        A ``tf.train.Example`` whose box coordinates are divided by the
        image width/height.
    """
    pil_image = tf.keras.utils.load_img(os.path.join(DATASET_PATH, group.filename))
    width, height = pil_image.size
    # The payload below is always re-encoded as JPEG, so the format tag must
    # say so regardless of the source file type. (The PIL ``format``
    # attribute can also be None after a colour-mode conversion.)
    image_format = b'jpeg'
    # encode_jpeg expects uint8 pixels; request that dtype explicitly rather
    # than relying on an implicit cast of the default float32 array.
    encoded_jpeg = tf.io.encode_jpeg(
        tf.keras.utils.img_to_array(pil_image, dtype='uint8')
    )
    filename = group.filename.encode('utf8')

    objects = group.object
    class_names = list(objects['classname'])
    # Box coordinates are stored relative to the image dimensions.
    xmins = (objects['xmin'] / width).tolist()
    xmaxs = (objects['xmax'] / width).tolist()
    ymins = (objects['ymin'] / height).tolist()
    ymaxs = (objects['ymax'] / height).tolist()
    classes_text = [name.encode('utf8') for name in class_names]
    # NOTE(review): class_text_to_int returns None for labels it does not
    # know, which would end up in the int64 list below -- confirm that the
    # annotations only contain known labels.
    classes = [class_text_to_int(name) for name in class_names]

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpeg.numpy()),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
¿Quieres ver más aportes, preguntas y respuestas de la comunidad?