generated from rom1504/python-template
- Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathinference_example.py
98 lines (73 loc) · 3.21 KB
/
inference_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
"""
This is an example on how to use embedding reader to do an inference over a set of billion
of clip vit-l/14 embeddings to predict whether the corresponding images are safe or not
"""
fromembedding_readerimportEmbeddingReader
importfire
importos
os.environ["CUDA_VISIBLE_DEVICES"] =""
importnumpyasnp
importfsspec
importmath
importpandasaspd
importgc
def load_safety_model():
    """Load (downloading and caching on first use) the CLIP-based NSFW classifier.

    The autokeras model is cached under ``~/.cache/clip_retrieval``. When the
    cached model directory is missing, the zip archive is fetched from the
    LAION-AI/CLIP-based-NSFW-Detector repository and extracted into the cache.
    A warmup ``predict`` on random (1000, 768) float32 embeddings is run so the
    first real batch does not pay the graph-building cost.

    Returns:
        The loaded keras model; callers feed it (N, 768) float32 CLIP ViT-L/14
        embeddings.
    """
    import autokeras as ak  # pylint: disable=import-outside-toplevel
    from tensorflow.keras.models import load_model  # pylint: disable=import-outside-toplevel
    from os.path import expanduser  # pylint: disable=import-outside-toplevel

    home = expanduser("~")
    cache_folder = home + "/.cache/clip_retrieval"
    model_dir = cache_folder + "/clip_autokeras_binary_nsfw"
    if not os.path.exists(model_dir):
        os.makedirs(cache_folder, exist_ok=True)
        from urllib.request import urlretrieve  # pylint: disable=import-outside-toplevel

        path_to_zip_file = cache_folder + "/clip_autokeras_binary_nsfw.zip"
        url_model = (
            "https://raw.githubusercontent.com/LAION-AI/CLIP-based-NSFW-Detector/main/clip_autokeras_binary_nsfw.zip"
        )
        urlretrieve(url_model, path_to_zip_file)
        import zipfile  # pylint: disable=import-outside-toplevel

        with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
            zip_ref.extractall(cache_folder)
    loaded_model = load_model(model_dir, custom_objects=ak.CUSTOM_OBJECTS)
    # Warmup: run one throwaway prediction so later calls are not slowed by lazy init.
    loaded_model.predict(np.random.rand(10**3, 768).astype("float32"), batch_size=10**3)
    return loaded_model
importmmh3
def compute_hash(url, text):
    """Return a deterministic 64-bit murmurhash of the url+caption pair.

    ``None`` values are normalized to the empty string so rows with a missing
    url or caption still produce a stable hash.

    Args:
        url: image url string, or None.
        text: caption string, or None.

    Returns:
        The first (signed 64-bit int) component of ``mmh3.hash64``.
    """
    if url is None:
        url = ""
    if text is None:
        text = ""
    total = (url + text).encode("utf-8")
    return mmh3.hash64(total)[0]
def main(
    embedding_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion1B-nolang/img_emb/",
    metadata_folder="https://mystic.the-eye.eu/public/AI/cah/laion5b/embeddings/laion1B-nolang/laion1B-nolang-metadata/",
    output_folder="output",
    batch_size=10**5,
    end=None,
):
    """Run NSFW inference over a large collection of CLIP embeddings.

    Streams (embeddings, metadata) batches with EmbeddingReader, scores every
    embedding with the safety model, and writes one parquet file per batch
    containing the prediction, a url+caption hash, and the url.

    Args:
        embedding_folder: folder or URL holding the npy embedding shards.
        metadata_folder: folder or URL holding the parquet metadata shards.
        output_folder: destination (any fsspec-compatible path) for results.
        batch_size: number of embeddings scored and written per output file.
        end: optional global index at which to stop reading (None = all).
    """
    reader = EmbeddingReader(
        embedding_folder, metadata_folder=metadata_folder, file_format="parquet_npy", meta_columns=["url", "caption"]
    )
    fs, relative_output_path = fsspec.core.url_to_fs(output_folder)
    fs.mkdirs(relative_output_path, exist_ok=True)
    model = load_safety_model()
    total = reader.count
    # Fix: the original computed math.ceil(total // batch_size) — `//` already
    # floors, making ceil a no-op and undercounting by one whenever total is
    # not an exact multiple of batch_size. Use true division before ceil.
    batch_count = math.ceil(total / batch_size)
    # Zero-pad batch indices so the per-batch output files sort lexicographically.
    padding = int(math.log10(batch_count)) + 1
    import tensorflow as tf  # pylint: disable=import-outside-toplevel,unused-import

    for i, (embeddings, ids) in enumerate(reader(batch_size=batch_size, start=0, end=end)):
        predictions = model.predict_on_batch(embeddings)
        batch = np.hstack(predictions)
        padded_id = str(i).zfill(padding)
        output_file_path = os.path.join(relative_output_path, padded_id + ".parquet")
        df = pd.DataFrame(batch, columns=["prediction"])
        df["hash"] = [compute_hash(x, y) for x, y in zip(ids["url"], ids["caption"])]
        df["url"] = ids["url"]
        with fs.open(output_file_path, "wb") as f:
            df.to_parquet(f)
# Script entry point: expose `main` as a CLI via python-fire.
if __name__ == "__main__":
    fire.Fire(main)