faiss-7: 基础索引类型
数据准备import faissimport numpy as npd = 512# 维数# 向量集合n_data = 2000np.random.seed(0)data = []mu = 3sigma = 0.1for i in range(n_data):data.append(np.random.normal(mu, sigma, d))data = np.array(data).astyp
数据准备
import faiss
import numpy as np
# Synthetic dataset shared by every index example below.
d = 512  # vector dimensionality
mu = 3  # mean of the Gaussian each component is drawn from
sigma = 0.1  # standard deviation of that Gaussian

# Database vectors: n_data rows of d Gaussian samples, drawn row by row so the
# random stream (and therefore the printed results) is reproducible with seed 0.
n_data = 2000
np.random.seed(0)
data = np.array(
    [np.random.normal(mu, sigma, d) for _ in range(n_data)]
).astype('float32')

# Query vectors: same distribution, re-seeded with 12 so they are drawn from a
# different stream than the database rows.
n_query = 10
np.random.seed(12)
query = np.array(
    [np.random.normal(mu, sigma, d) for _ in range(n_query)]
).astype('float32')
精确搜索-L2距离(Exact Search for L2)
baseline
# Baseline: a flat index compares each query against every stored vector.
# The distances faiss reports for this index are squared L2 distances.
index = faiss.IndexFlatL2(d)
# Equivalent construction through the factory string:
#   index = faiss.index_factory(d, "Flat")
index.add(data)
# For every query row, fetch the 10 closest database vectors:
# `dis` holds the distances, `ind` the matching row numbers in `data`.
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[8.61838 8.782156 8.782816 8.832027 8.837635 8.8484955 8.897978
8.9166355 8.919006 8.937399 ]
[9.033302 9.038906 9.091706 9.155842 9.164592 9.200113 9.201885
9.220333 9.279479 9.312859 ]
[8.063819 8.211029 8.306456 8.373353 8.459253 8.459894 8.498556
8.546466 8.555407 8.621424 ]
[8.193895 8.211957 8.34701 8.446963 8.45299 8.45486 8.473572
8.504771 8.513636 8.530685 ]
[8.369623 8.549446 8.704066 8.736764 8.760081 8.777317 8.831345
8.835485 8.858271 8.860057 ]
[8.299071 8.432397 8.434382 8.457373 8.539217 8.562357 8.579033
8.618738 8.630859 8.6433935]
[8.615003 8.615164 8.72604 8.730944 8.762621 8.796932 8.797066
8.797366 8.813984 8.834725 ]
[8.377228 8.522776 8.711159 8.724562 8.745737 8.763845 8.7686
8.7728 8.786858 8.828223 ]
[8.3429165 8.488056 8.655106 8.662771 8.701336 8.741288 8.7436075
8.770506 8.786265 8.8490505]
[8.522163 8.575702 8.684618 8.767246 8.782908 8.850494 8.883732
8.903692 8.909395 8.917681 ]]
[[1269 1525 1723 1160 1694 48 1075 1028 544 916]
[1035 259 1279 1116 1398 879 289 882 1420 1927]
[ 327 345 1401 389 1904 1992 1612 106 981 1179]
[1259 112 351 804 1412 1987 1377 250 1624 133]
[1666 854 1135 616 94 280 30 99 1212 3]
[ 574 1523 366 766 1046 91 456 649 46 896]
[1945 944 244 655 1686 981 256 1555 1280 1969]
[ 879 1025 390 269 1115 1662 1831 610 11 191]
[ 156 154 99 31 1237 289 769 1524 56 661]
[ 427 182 375 1826 610 1384 1299 750 2 1430]]
精确搜索-点乘距离(Exact Search for Inner Product)
当数据库向量和查询向量都经过L2归一化(单位长度)时,返回的distance就是余弦相似度;本例中的向量没有做归一化,因此下面输出的是原始内积值。
# Exact search ranked by inner product (larger value = better match).
index = faiss.IndexFlatIP(d)
index.add(data)
# Top-10 database rows per query; `dis` contains inner products here,
# not distances, so each row is sorted in decreasing order.
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[4621.75 4621.5464 4619.745 4619.381 4619.176 4618.0625 4617.169
4617.057 4617.048 4616.6304]
[4637.3965 4637.289 4635.3677 4635.2446 4634.881 4633.6074 4633.021
4632.7646 4632.5596 4632.373 ]
[4621.755 4621.47 4619.748 4619.562 4619.423 4618.0186 4616.992
4616.962 4616.9014 4616.735 ]
[4623.608 4623.5596 4621.3965 4621.1577 4620.9062 4619.838 4618.9756
4618.913 4618.7695 4618.4775]
[4625.553 4625.064 4623.461 4623.1963 4622.957 4621.337 4620.7363
4620.717 4620.5645 4620.248 ]
[4628.489 4628.449 4626.4917 4626.4873 4625.6406 4624.615 4624.29
4623.999 4623.7524 4623.618 ]
[4637.746 4637.338 4635.3047 4635.126 4634.7476 4633.0137 4632.8633
4632.58 4632.3027 4632.233 ]
[4630.4717 4630.334 4628.2646 4627.9375 4627.7383 4626.8975 4625.8135
4625.7227 4625.445 4625.0913]
[4635.7725 4635.4893 4633.6904 4633.5674 4632.658 4631.4634 4631.4307
4631.101 4630.99 4630.3066]
[4625.6763 4625.558 4623.454 4623.3916 4623.324 4622.2827 4621.7783
4621.1147 4620.9043 4620.8545]]
[[1562 27 681 169 1262 942 1566 31 1207 252]
[ 27 1562 169 681 1262 942 1566 1392 31 252]
[1562 27 681 169 1262 942 1566 252 31 1392]
[1562 27 681 169 1262 942 252 1566 31 1207]
[1562 27 681 169 1262 942 1566 252 31 1513]
[1562 27 169 681 1262 942 252 31 1566 911]
[1562 27 169 681 1262 1566 942 252 31 1392]
[1562 27 681 169 1262 942 1566 31 252 1207]
[ 27 1562 169 681 1262 942 31 252 1566 1392]
[1562 27 1262 681 169 942 1566 31 1207 252]]
HNSW(Hierarchical Navigable Small World graph exploration)
基于图的近似检索方法,返回的是近似最近邻结果(可与上面精确搜索的输出对比)。
# Graph-based approximate search over the uncompressed vectors.
# NOTE(review): the second argument (16) is presumably HNSW's M, the number
# of links per graph node — confirm against the faiss documentation.
index = faiss.IndexHNSWFlat(d, 16)
index.add(data)
# Approximate top-10 per query; results may miss some true neighbours
# found by the exact IndexFlatL2 search.
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[8.61838 8.832027 8.8484955 8.897978 8.9166355 8.919006 8.937399
8.9597 8.984709 8.998905 ]
[9.038906 9.164592 9.200113 9.201885 9.220333 9.312859 9.344341
9.34485 9.416972 9.421429 ]
[8.063819 8.459253 8.459894 8.498556 8.555407 8.631897 8.71368
8.735945 8.770473 8.792957 ]
[8.193895 8.211957 8.34701 8.446963 8.45486 8.473572 8.504771
8.513636 8.530685 8.545483 ]
[8.369623 8.760081 8.831345 8.860057 8.862643 8.93695 8.972281
8.996923 9.065968 9.070428 ]
[8.299071 8.432397 8.434382 8.539217 8.562357 8.6433935 8.6983185
8.753673 8.768753 8.780443 ]
[8.762621 8.796932 8.797066 8.813984 8.860753 8.867388 8.911812
8.922768 8.928856 8.942961 ]
[8.377228 8.522776 8.711159 8.724562 8.7728 8.828223 8.879469
8.888437 8.914921 8.924161 ]
[8.3429165 8.488056 8.662771 8.741288 8.7436075 8.770506 8.857254
8.893715 8.933592 8.960606 ]
[8.522163 8.575702 8.850494 8.903692 8.917681 8.936615 8.961666
8.977329 9.009894 9.031724 ]]
[[1269 48 1075 1028 916 239 897 1627 120 1567]
[ 259 1398 879 289 882 1927 13 70 1023 121]
[1401 1904 106 981 1623 1393 1632 539 1143 366]
[1259 112 351 804 1987 1377 250 1624 133 879]
[1666 94 1212 277 1723 581 106 472 807 884]
[ 574 1523 366 766 1046 91 154 911 902 685]
[ 944 1686 981 1391 849 280 1337 1263 91 1540]
[ 879 1025 390 269 1115 1831 610 11 191 1686]
[ 154 31 1237 289 769 1524 661 426 1008 1727]
[ 182 1299 1430 511 1339 1010 1173 1457 664 529]]
倒排表搜索(Inverted file with exact post-verification)
# Inverted-file index: vectors are bucketed by a coarse quantizer and only
# some buckets are scanned at query time.
nlist = 50  # number of inverted lists (clusters) to partition the data into
quantizer = faiss.IndexFlatL2(d)  # coarse quantizer used to assign vectors to lists
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
# The index is trained on the data before any vectors are added.
index.train(data)
index.add(data)
# Number of Voronoi cells (inverted lists) visited per query.
index.nprobe = 30
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[8.61838 8.782156 8.782816 8.832027 8.837635 8.8484955 8.897978
8.9166355 8.919006 8.937399 ]
[9.033302 9.038906 9.091706 9.164592 9.200113 9.201885 9.220333
9.279479 9.312859 9.344341 ]
[8.063819 8.211029 8.306456 8.373353 8.459253 8.459894 8.498556
8.546466 8.555407 8.621424 ]
[8.193895 8.211957 8.34701 8.446963 8.45299 8.45486 8.473572
8.504771 8.513636 8.530685 ]
[8.369623 8.549446 8.704066 8.736764 8.760081 8.777317 8.831345
8.835485 8.858271 8.860057 ]
[8.299071 8.432397 8.434382 8.457373 8.539217 8.562357 8.579033
8.618738 8.630859 8.6983185]
[8.615003 8.615164 8.72604 8.730944 8.762621 8.796932 8.797066
8.797366 8.813984 8.834725 ]
[8.377228 8.522776 8.711159 8.724562 8.745737 8.763845 8.7686
8.7728 8.786858 8.828223 ]
[8.3429165 8.488056 8.655106 8.662771 8.701336 8.741288 8.7436075
8.770506 8.786265 8.8490505]
[8.522163 8.575702 8.684618 8.767246 8.782908 8.850494 8.883732
8.903692 8.909395 8.917681 ]]
[[1269 1525 1723 1160 1694 48 1075 1028 544 916]
[1035 259 1279 1398 879 289 882 1420 1927 13]
[ 327 345 1401 389 1904 1992 1612 106 981 1179]
[1259 112 351 804 1412 1987 1377 250 1624 133]
[1666 854 1135 616 94 280 30 99 1212 3]
[ 574 1523 366 766 1046 91 456 649 46 154]
[1945 944 244 655 1686 981 256 1555 1280 1969]
[ 879 1025 390 269 1115 1662 1831 610 11 191]
[ 156 154 99 31 1237 289 769 1524 56 661]
[ 427 182 375 1826 610 1384 1299 750 2 1430]]
LSH(Locality-Sensitive Hashing (binary flat index))
# Locality-sensitive hashing: each vector is reduced to a binary code.
nbits = 2 * d  # code length in bits (twice the input dimensionality here)
index = faiss.IndexLSH(d, nbits)
index.train(data)
index.add(data)
# `dis` for this index comes out as small whole numbers (see the printed
# output), unlike the float L2 distances of the other indexes.
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[ 8. 10. 10. 10. 10. 10. 10. 11. 11. 11.]
[ 7. 8. 9. 9. 9. 10. 10. 10. 10. 10.]
[ 7. 8. 8. 9. 9. 9. 9. 9. 9. 9.]
[ 9. 9. 10. 11. 12. 12. 12. 12. 12. 12.]
[ 6. 6. 6. 7. 7. 8. 8. 8. 8. 8.]
[ 8. 8. 8. 9. 9. 9. 9. 9. 10. 10.]
[ 6. 7. 8. 8. 9. 9. 9. 9. 9. 9.]
[ 9. 9. 9. 9. 9. 9. 9. 9. 9. 10.]
[ 7. 8. 8. 8. 8. 8. 8. 9. 9. 9.]
[ 9. 9. 9. 10. 10. 10. 10. 10. 10. 10.]]
[[1424 345 1544 760 1589 1043 668 492 148 666]
[1974 1436 1476 51 711 696 28 934 541 1125]
[1667 1356 1149 512 1592 1544 1677 309 1021 1018]
[ 708 107 606 243 18 612 598 615 269 250]
[1455 541 1142 843 1140 888 165 961 797 1003]
[1735 193 953 1071 1518 1109 449 263 1329 1216]
[1129 1231 1731 123 860 907 381 993 336 1071]
[1622 336 1970 845 70 1921 1973 980 331 42]
[1335 713 1589 395 263 1206 346 698 913 678]
[1395 279 1305 427 1707 1574 1710 226 1205 1160]]
SQ量化(Scalar quantizer (SQ) in flat mode)
# Scalar quantizer in flat mode: every vector component is stored in a
# compressed scalar encoding, and search scans all stored codes.
# The second argument is a ScalarQuantizer.QuantizerType enum value, not a bit
# count: the literal 4 used originally is QT_fp16 (16-bit floats), which is
# why the distances below are almost identical to the exact L2 results.
index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_fp16)
index.train(data)
index.add(data)
# Top-10 neighbours per query, computed on the decoded (fp16) vectors.
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[8.623228 8.777794 8.785317 8.828827 8.835491 8.845296 8.896896
8.914822 8.922382 8.934984 ]
[9.028503 9.037548 9.099254 9.152615 9.16542 9.196389 9.200497
9.224977 9.274048 9.305386 ]
[8.064029 8.213011 8.310526 8.376434 8.4578285 8.462004 8.500969
8.550645 8.556992 8.624526 ]
[8.196653 8.210537 8.3464365 8.444774 8.452023 8.454119 8.474525
8.4966135 8.510042 8.525611 ]
[8.370451 8.547961 8.704324 8.733618 8.763927 8.776734 8.82951
8.835646 8.857151 8.859047 ]
[8.29591 8.432422 8.435947 8.454732 8.542397 8.565366 8.579685
8.621871 8.632036 8.64478 ]
[8.609016 8.612934 8.726631 8.734137 8.758858 8.797329 8.797971
8.798654 8.815296 8.838221 ]
[8.37895 8.521536 8.710902 8.726156 8.748387 8.75966 8.768217
8.769184 8.792372 8.834644 ]
[8.340463 8.489507 8.659348 8.664953 8.702758 8.741514 8.741941
8.768995 8.781276 8.852154 ]
[8.520282 8.5749855 8.68346 8.769207 8.782043 8.851276 8.881114
8.906744 8.907756 8.924014 ]]
[[1269 1723 1525 1160 1694 48 1075 544 1028 916]
[1035 259 1279 1116 1398 289 879 882 1420 1927]
[ 327 345 1401 389 1992 1904 1612 106 981 1179]
[1259 112 351 804 1987 1412 1377 250 1624 133]
[1666 854 1135 616 94 280 30 99 3 1212]
[ 574 1523 366 766 1046 91 456 649 46 896]
[ 944 1945 244 655 1686 256 981 1555 1280 1969]
[ 879 1025 390 269 1115 1662 1831 610 11 191]
[ 156 154 99 31 1237 289 769 1524 56 661]
[ 427 182 375 1826 610 1384 1299 750 2 1430]]
PQ量化(Product quantizer (PQ) in flat mode)
# Product quantizer in flat mode: each vector is split into M sub-vectors and
# every sub-vector is replaced by the id of its nearest centroid.
M = 8  # number of sub-quantizers; must be a divisor of d
# Bits per sub-quantizer code. The original note said only 8, 12 or 16 are
# allowed, yet this example runs with 6 and produced the output shown below.
# NOTE(review): check which nbits values the installed faiss version supports.
nbits = 6
index = faiss.IndexPQ(d, M, nbits)
index.train(data)
index.add(data)
# Distances are computed on the quantized codes, so they are approximate and
# differ noticeably from the exact L2 results.
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[5.3148193 5.33667 5.390381 5.3969727 5.4020996 5.402466 5.4088135
5.420532 5.4210205 5.4370117]
[5.694214 5.71875 5.7408447 5.7418213 5.743042 5.7543945 5.7611084
5.764282 5.786255 5.791748 ]
[4.880615 4.9449463 4.9451904 5.0146484 5.022583 5.0406494 5.0444336
5.0650635 5.06604 5.0666504]
[4.8305664 4.852661 4.8569336 4.87146 4.8901367 4.8969727 4.9003906
4.9073486 4.9093018 4.911743 ]
[5.2233887 5.3170166 5.3239746 5.338623 5.347534 5.356201 5.3599854
5.362915 5.387085 5.40271 ]
[5.024658 5.0458984 5.0463867 5.069214 5.1254883 5.1533203 5.1557617
5.17395 5.1887207 5.188843 ]
[5.070923 5.1273193 5.144409 5.1882324 5.1896973 5.194702 5.20813
5.223633 5.223633 5.242798 ]
[5.173218 5.2506104 5.26355 5.3077393 5.3099365 5.3240967 5.324585
5.3448486 5.3449707 5.3479004]
[5.1730957 5.246826 5.2974854 5.319092 5.3239746 5.3275146 5.3409424
5.3464355 5.352051 5.364624 ]
[5.204468 5.27124 5.272217 5.324463 5.335449 5.348755 5.3531494
5.355713 5.359619 5.3670654]]
[[1000 1775 1651 702 1963 1063 249 1995 689 1075]
[ 243 532 1923 52 304 1212 449 1264 1092 622]
[1904 981 735 492 458 1810 1945 839 875 616]
[1307 148 250 1773 576 187 864 394 1920 1550]
[1135 1429 151 773 250 502 1945 1408 694 1849]
[1427 1463 816 1314 1096 896 8 1366 1673 939]
[ 244 854 85 560 1154 1473 951 1626 218 885]
[ 278 1176 787 70 235 326 190 1843 892 1756]
[ 81 855 416 1545 145 1811 172 383 1856 90]
[ 902 1238 725 1141 1255 593 1507 596 1400 1434]]
倒排表乘积量化(IVFADC (coarse quantizer+PQ on residuals))
# IVFADC: inverted file for coarse partitioning plus product quantization of
# the stored vectors.
nlist = 50  # number of inverted lists (coarse clusters)
M = 8  # number of PQ sub-quantizers
nbits = 4  # bits per sub-quantizer code
quantizer = faiss.IndexFlatL2(d)  # coarse quantizer that assigns vectors to lists
index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)
index.train(data)
index.add(data)
# nprobe is not set here, so the index's default applies (unlike the
# IndexIVFFlat example, which sets it to 30).
dis, ind = index.search(query, 10)
print(dis, ind, sep="\n")
[[5.1813607 5.213015 5.227062 5.240297 5.3074746 5.3143606 5.3209453
5.323071 5.3241897 5.3336196]
[5.531564 5.5329485 5.562758 5.563456 5.5912466 5.6556826 5.672931
5.68467 5.6997647 5.702039 ]
[4.8162827 4.822724 4.8304024 4.8753777 4.885644 4.8879986 4.8881545
4.893022 4.8931584 4.900308 ]
[4.816623 4.83104 4.8495483 4.8536286 4.876833 4.877784 4.880099
4.884579 4.8872194 4.891768 ]
[5.0868177 5.118717 5.1225142 5.1229815 5.1336117 5.1365952 5.1425185
5.1445856 5.1706796 5.1717625]
[4.896156 4.923249 4.941252 4.9426517 4.951362 4.9712415 4.9738026
4.9810586 4.9901094 4.991113 ]
[4.98639 4.9903703 4.9991083 5.0107374 5.011694 5.0129166 5.016504
5.020243 5.0209174 5.024062 ]
[5.175429 5.175681 5.1776314 5.189157 5.1944485 5.222612 5.2232976
5.226807 5.236521 5.2451673]
[5.044052 5.080878 5.1016216 5.109776 5.1117053 5.124481 5.1256466
5.1480865 5.151196 5.1520753]
[5.129094 5.1587934 5.1708508 5.171063 5.18175 5.182567 5.1942434
5.1995926 5.201419 5.2037187]]
[[1962 1880 311 1897 666 201 647 283 1588 171]
[ 569 148 162 39 753 1032 983 934 560 1715]
[ 851 380 1322 1803 1678 1486 1504 1847 1206 306]
[ 960 1741 636 510 1568 880 866 1134 615 381]
[1799 1166 572 1631 1244 343 1212 1859 756 1630]
[ 816 1753 749 1621 444 1399 658 771 369 913]
[1902 1238 913 1546 654 969 1187 350 979 1251]
[ 464 839 1705 772 77 490 1976 998 968 57]
[1445 505 426 300 1220 122 806 755 100 1343]
[1432 1263 1198 1537 1816 263 1 430 598 260]]
cell-probe方法
为了加速索引过程,经常采用划分子类空间(如k-means)的方法,虽然这样无法保证最后返回的结果是完全正确的。先划分子类空间,再在部分子空间中搜索的方法,就是cell-probe方法。
具体流程为:
- 数据集空间被划分为n个部分,在k-means中,表现为n个类;
- 每个类中的向量保存在一个倒排表中,共有n个倒排表;
- 查询时,选中nprobe个倒排表;
- 将这几个倒排表中的向量与查询向量作对比。
在这种方法中,只需要将数据库中的一部分向量与查询向量作比较,这部分的比例大约是nprobe/n;这只是一个粗略的估计(通常偏低),因为每个倒排表的长度并不一致(每个类中的向量个数不一定相等)。
cell-probe粗量化
在一些索引类型(如IndexIVFFlat)中,需要一个Flat index作为粗量化器:训练时会将类中心保存在这个Flat index中;在add和search阶段,会首先用它判定向量落入哪个类空间。在search阶段,需要调整nprobe参数,以权衡检索精度与检索速度。
实验表明,对高维数据,需要维持比较高的nprobe数值才能保证精度。
与LSH的优劣
LSH也是一种cell-probe方法。与上面基于聚类的cell-probe方法相比,LSH有以下几点不足:
- LSH需要大量的哈希函数才能得到可接受的结果,会带来额外的内存开销;
- 哈希函数不会根据输入数据的分布进行调整。
参考
开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!
更多推荐
所有评论(0)