1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
|
import pandas as pd
import time
from tqdm import tqdm
# Batch-predict a label and its probability for every row of a tab-separated
# file, streaming the file in chunks.
#
# Reads `path` (no header row; the two columns are named idx and text), encodes
# each text with `tokenizer`, scores it with `model` (weights loaded from
# `model_path`), appends each scored chunk to `csv_path`, and also collects all
# scored rows into `df_all`. The names path, model_path, csv_path, model,
# tokenizer, maxlen and sequence_padding must be defined before this runs.
# NOTE(review): csv_path is opened in append mode, so rerunning the script adds
# to any output left by a previous run — delete the file first if unwanted.
start = time.time()
# Output schema; also the value of df_all when the input yields no chunks.
blank_df = pd.DataFrame(columns=['idx', 'text', 'labels', 'probs'])
batch_size = 32
chunk_size = batch_size * 100  # rows read per chunk = 100 prediction batches
model.load_weights(model_path)
chunk_df = pd.read_table(path, sep="\t", header=None, chunksize=chunk_size)

parts = []  # scored chunks; concatenated once after the loop (avoids O(n^2) copying)
try:
    for df_sample in tqdm(chunk_df):
        df_sample.columns = ['idx', 'text']
        tokens, segments = [], []
        for text in df_sample.text:
            try:
                token_ids, segment_ids = tokenizer.encode(text, max_length=maxlen)
            except Exception:
                # Best-effort: a text the tokenizer cannot encode (presumably
                # e.g. NaN from an empty field — TODO confirm) still gets a row,
                # so predictions stay aligned with df_sample. 101/102 are
                # presumably the [CLS]/[SEP] ids of a BERT vocab — verify
                # against the tokenizer's vocabulary.
                token_ids = [101, 102]
                segment_ids = [0, 0]
            tokens.append(token_ids)
            segments.append(segment_ids)
        y_preds = model.predict(
            [sequence_padding(tokens), sequence_padding(segments)],
            batch_size=batch_size,
        )
        # Predicted class = argmax over the last axis; probs = its score.
        df_sample['labels'] = y_preds.argmax(axis=1)
        df_sample['probs'] = y_preds.max(axis=1)
        # header must be disabled here; otherwise the column names would be
        # appended into the table body once per chunk.
        df_sample.to_csv(csv_path, index=False, header=False, mode='a')
        parts.append(df_sample)
finally:
    # Release the underlying file handle even if encoding/prediction fails.
    chunk_df.close()

df_all = pd.concat(parts) if parts else blank_df.copy()
end = time.time()
print(round(end - start, 2))
|