自然语言处理文本分类：分块预测大文件。

import pandas as pd
import time
from tqdm import tqdm

# Chunked batch prediction over a large tab-separated file.
#
# The input is read `chunk_size` rows at a time so the raw file is never fully
# materialized by read_table. Each chunk is tokenized, scored by the model,
# and appended to `csv_path`. The chunks are also collected into `df_all`.
#
# NOTE(review): this script relies on names defined elsewhere (not visible
# here): `path` (input file), `csv_path` (output file), `model` / `model_path`,
# `tokenizer` (whose `encode(text, max_length=...)` is used as returning a
# (token_ids, segment_ids) pair), `maxlen`, and `sequence_padding`.
start = time.time()
blank_df = pd.DataFrame(columns=['idx', 'text', 'labels', 'probs'])
batch_size = 32
# Each chunk holds 100 prediction batches.
chunk_size = batch_size * 100
# header=None: the input is read as having no header row; columns are named
# per chunk below (the file is assumed to have exactly two columns -- confirm).
chunk_df = pd.read_table(path, sep="\t", header=None, chunksize=chunk_size)

model.load_weights(model_path)

chunk_results = []
for df_sample in tqdm(chunk_df):
    df_sample.columns = ['idx', 'text']

    tokens, segments = [], []
    for text in df_sample.text:
        try:
            token_ids, segment_ids = tokenizer.encode(text, max_length=maxlen)
        except Exception:
            # Keep the row so predictions stay aligned with the input rows; it
            # is scored from a minimal two-token fallback sequence instead.
            # Presumably 101/102 are the BERT [CLS]/[SEP] ids and the failure
            # is a non-string cell (e.g. NaN from an empty field) -- verify.
            token_ids = [101, 102]
            segment_ids = [0, 0]

        tokens.append(token_ids)
        segments.append(segment_ids)

    y_preds = model.predict(
        [sequence_padding(tokens), sequence_padding(segments)],
        batch_size=batch_size,
    )
    # Predicted class = index of the highest score; probs = that score.
    df_sample['labels'] = y_preds.argmax(axis=1)
    df_sample['probs'] = y_preds.max(axis=1)

    # header must be disabled here: with mode='a', every chunk would otherwise
    # append the column names into the file as an extra data row.
    df_sample.to_csv(csv_path, index=False, header=False, mode='a')
    chunk_results.append(df_sample)

# Concatenate once at the end. Concatenating inside the loop re-copies all
# previously collected rows on every chunk (quadratic in the chunk count).
# NOTE(review): df_all holds every row in memory; drop it if only the CSV
# output is needed for very large inputs.
df_all = pd.concat(chunk_results) if chunk_results else blank_df.copy()

end = time.time()
print(round(end - start, 2))

Reference

打赏

微信 / 支付宝
万分感谢