In [1]:
import pandas as pd
import os
import pathlib
from ggplot import aes, ggplot, geom_point, stat_smooth, geom_text, ggtitle, geom_bar, ylim
In [111]:
def plot(df, step, metric, smooth=True, point=True, alpha=0.3, size=3):
    g = ggplot(aes(x=step, y=metric, color='name'), data=df)
    if point:
        g += geom_point(alpha=alpha)
    if smooth:
        g += stat_smooth(se=False, size=size)
    #g += geom_text(aes(label=metric),hjust=0, vjust=0, size=10)
    #g += ggtitle(metric)
    display(g)
In [3]:
def show(df, metric, low=0.0, high=1.0, dir='max',
         percentage=None, graph=True):
    x = df.groupby('name')[metric]
    x = getattr(x, dir)()
    x = x.reset_index()
    if percentage and not graph:
        x[metric] = x[metric].map("{:.2%}".format)
    df = x
    if graph:
        x = ggplot(aes(x='name', y=metric, weight=metric, fill='name'), data=x)
        x += geom_bar()
        x += ylim(low=low, high=high)
        x += ggtitle(u'performance: {}'.format(metric))
        display(x)
    else:
        display(df)
In [4]:
%matplotlib inline
In [5]:
def get_name(x):
    return os.path.splitext(x.name)[0]
In [65]:
def read_df(dir_name,  mapping={}, slice=None):
    d = dict([get_name(file), pd.read_csv(file.as_posix())] for file in\
            pathlib.Path(dir_name).iterdir() if file.is_file())
    collect = []
    for k, v in d.items():
        v.insert(0, 'name', k)
        if mapping:
            for old, new in mapping.items():
                v.insert(0, new, v[old])
        collect.append(v)
    df = pd.concat(collect)
    if slice:
        df = df[slice + ['name']]
        df.reset_index()
    return df
In [66]:
!ls bert_stat/ --color=always
4block_bert.csv  acc  bert.csv  loss  word_bert.csv

训练过程

In [114]:
acc = read_df('bert_stat/acc', mapping={'Value': 'acc'}, slice=['Step', 'acc'])
In [118]:
plot(acc, 'Step', 'acc', point=True, alpha=0)
<ggplot: (8735935513025)>
In [116]:
loss = read_df('bert_stat/loss', mapping={'Value': 'loss'}, slice=['Step', 'loss'])
In [117]:
plot(loss, 'Step', 'loss')
<ggplot: (8736009145113)>

验证集性能

In [84]:
df = read_df('bert_stat')
In [88]:
plot(df, 'epoch', 'val_acc')
<ggplot: (8736009154245)>
In [90]:
plot(df, 'epoch', 'val_fscore')
<ggplot: (8735937561577)>
In [91]:
show(df, 'val_acc', low=0.7, high=0.85)
<ggplot: (8735937470677)>
In [13]:
show(df, 'val_fscore', low=0.65, high=0.8)
<ggplot: (8736021161237)>
In [92]:
show(df, 'val_loss', 'min', graph=False)
name val_loss
0 4block_bert 0.650032
1 bert 0.890977
2 word_bert 0.567510