数据准备

import faiss
import numpy as np 


d = 512          # vector dimensionality
mu = 3           # mean of the Gaussian used for every vector component
sigma = 0.1      # standard deviation of that Gaussian

# Database vectors: n_data rows of d i.i.d. N(mu, sigma**2) components.
# A single vectorized draw consumes the legacy global RNG stream in the same
# order as drawing one row of length d at a time, so the values are identical
# to a per-row loop -- just without the Python-level loop and list building.
n_data = 2000
np.random.seed(0)
data = np.random.normal(mu, sigma, (n_data, d)).astype('float32')

# Query vectors: same distribution, separate seed.
n_query = 10
np.random.seed(12)
query = np.random.normal(mu, sigma, (n_query, d)).astype('float32')

精确搜索-L2距离(Exact Search for L2)

baseline

# Brute-force (exact) k-NN baseline: every query is compared against every
# stored vector; returned distances are squared L2, sorted ascending per row.
index = faiss.IndexFlatL2(d)  # L2 distance
# index = faiss.index_factory(d, "Flat")  # equivalent alternative way to build the same index
index.add(data)  # a flat index stores the raw vectors; no train() step is used here
dis, ind = index.search(query, 10)  # top-10 per query: distances and row ids into `data`
print(dis)
print(ind)
[[8.61838   8.782156  8.782816  8.832027  8.837635  8.8484955 8.897978
  8.9166355 8.919006  8.937399 ]
 [9.033302  9.038906  9.091706  9.155842  9.164592  9.200113  9.201885
  9.220333  9.279479  9.312859 ]
 [8.063819  8.211029  8.306456  8.373353  8.459253  8.459894  8.498556
  8.546466  8.555407  8.621424 ]
 [8.193895  8.211957  8.34701   8.446963  8.45299   8.45486   8.473572
  8.504771  8.513636  8.530685 ]
 [8.369623  8.549446  8.704066  8.736764  8.760081  8.777317  8.831345
  8.835485  8.858271  8.860057 ]
 [8.299071  8.432397  8.434382  8.457373  8.539217  8.562357  8.579033
  8.618738  8.630859  8.6433935]
 [8.615003  8.615164  8.72604   8.730944  8.762621  8.796932  8.797066
  8.797366  8.813984  8.834725 ]
 [8.377228  8.522776  8.711159  8.724562  8.745737  8.763845  8.7686
  8.7728    8.786858  8.828223 ]
 [8.3429165 8.488056  8.655106  8.662771  8.701336  8.741288  8.7436075
  8.770506  8.786265  8.8490505]
 [8.522163  8.575702  8.684618  8.767246  8.782908  8.850494  8.883732
  8.903692  8.909395  8.917681 ]]
  
 [[1269 1525 1723 1160 1694   48 1075 1028  544  916]
 [1035  259 1279 1116 1398  879  289  882 1420 1927]
 [ 327  345 1401  389 1904 1992 1612  106  981 1179]
 [1259  112  351  804 1412 1987 1377  250 1624  133]
 [1666  854 1135  616   94  280   30   99 1212    3]
 [ 574 1523  366  766 1046   91  456  649   46  896]
 [1945  944  244  655 1686  981  256 1555 1280 1969]
 [ 879 1025  390  269 1115 1662 1831  610   11  191]
 [ 156  154   99   31 1237  289  769 1524   56  661]
 [ 427  182  375 1826  610 1384 1299  750    2 1430]]

精确搜索-点乘距离(Exact Search for Inner Product)

当查询向量和数据库向量都经过L2归一化时,返回的distance就是余弦相似度。注意:本例中的数据并未归一化,因此下面返回的是原始内积。

# Exact maximum-inner-product search; scores are dot products, sorted descending.
# NOTE: `data` and `query` are not normalized here, so these scores are raw inner
# products, not cosine similarities. To get cosine similarity, L2-normalize both
# arrays (e.g. faiss.normalize_L2, in place) before add/search. Without it,
# large-norm database vectors dominate, which is why nearly every query below
# returns the same top ids.
index = faiss.IndexFlatIP(d)  # inner product
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
print(ind)
[[4621.75   4621.5464 4619.745  4619.381  4619.176  4618.0625 4617.169
  4617.057  4617.048  4616.6304]
 [4637.3965 4637.289  4635.3677 4635.2446 4634.881  4633.6074 4633.021
  4632.7646 4632.5596 4632.373 ]
 [4621.755  4621.47   4619.748  4619.562  4619.423  4618.0186 4616.992
  4616.962  4616.9014 4616.735 ]
 [4623.608  4623.5596 4621.3965 4621.1577 4620.9062 4619.838  4618.9756
  4618.913  4618.7695 4618.4775]
 [4625.553  4625.064  4623.461  4623.1963 4622.957  4621.337  4620.7363
  4620.717  4620.5645 4620.248 ]
 [4628.489  4628.449  4626.4917 4626.4873 4625.6406 4624.615  4624.29
  4623.999  4623.7524 4623.618 ]
 [4637.746  4637.338  4635.3047 4635.126  4634.7476 4633.0137 4632.8633
  4632.58   4632.3027 4632.233 ]
 [4630.4717 4630.334  4628.2646 4627.9375 4627.7383 4626.8975 4625.8135
  4625.7227 4625.445  4625.0913]
 [4635.7725 4635.4893 4633.6904 4633.5674 4632.658  4631.4634 4631.4307
  4631.101  4630.99   4630.3066]
 [4625.6763 4625.558  4623.454  4623.3916 4623.324  4622.2827 4621.7783
  4621.1147 4620.9043 4620.8545]]
  
 [[1562   27  681  169 1262  942 1566   31 1207  252]
 [  27 1562  169  681 1262  942 1566 1392   31  252]
 [1562   27  681  169 1262  942 1566  252   31 1392]
 [1562   27  681  169 1262  942  252 1566   31 1207]
 [1562   27  681  169 1262  942 1566  252   31 1513]
 [1562   27  169  681 1262  942  252   31 1566  911]
 [1562   27  169  681 1262 1566  942  252   31 1392]
 [1562   27  681  169 1262  942 1566   31  252 1207]
 [  27 1562  169  681 1262  942   31  252 1566 1392]
 [1562   27 1262  681  169  942 1566   31 1207  252]]

HNSW(Hierarchical Navigable Small World graph exploration)

基于图的近似搜索方法,返回的是近似最近邻结果

# Approximate search over a Hierarchical Navigable Small World graph.
# The second constructor argument (16) is the HNSW "M" parameter: the number of
# neighbours each node links to in the graph. Search-time accuracy/speed is
# governed by index.hnsw.efSearch, left at faiss's default here; compare the
# output with the exact L2 baseline to see which neighbours were missed.
index = faiss.IndexHNSWFlat(d, 16)
index.add(data)  # the graph is built during add(); this snippet calls no train()
dis, ind = index.search(query, 10)
print(dis)
print(ind)
[[8.61838   8.832027  8.8484955 8.897978  8.9166355 8.919006  8.937399
  8.9597    8.984709  8.998905 ]
 [9.038906  9.164592  9.200113  9.201885  9.220333  9.312859  9.344341
  9.34485   9.416972  9.421429 ]
 [8.063819  8.459253  8.459894  8.498556  8.555407  8.631897  8.71368
  8.735945  8.770473  8.792957 ]
 [8.193895  8.211957  8.34701   8.446963  8.45486   8.473572  8.504771
  8.513636  8.530685  8.545483 ]
 [8.369623  8.760081  8.831345  8.860057  8.862643  8.93695   8.972281
  8.996923  9.065968  9.070428 ]
 [8.299071  8.432397  8.434382  8.539217  8.562357  8.6433935 8.6983185
  8.753673  8.768753  8.780443 ]
 [8.762621  8.796932  8.797066  8.813984  8.860753  8.867388  8.911812
  8.922768  8.928856  8.942961 ]
 [8.377228  8.522776  8.711159  8.724562  8.7728    8.828223  8.879469
  8.888437  8.914921  8.924161 ]
 [8.3429165 8.488056  8.662771  8.741288  8.7436075 8.770506  8.857254
  8.893715  8.933592  8.960606 ]
 [8.522163  8.575702  8.850494  8.903692  8.917681  8.936615  8.961666
  8.977329  9.009894  9.031724 ]]
  
 [[1269   48 1075 1028  916  239  897 1627  120 1567]
 [ 259 1398  879  289  882 1927   13   70 1023  121]
 [1401 1904  106  981 1623 1393 1632  539 1143  366]
 [1259  112  351  804 1987 1377  250 1624  133  879]
 [1666   94 1212  277 1723  581  106  472  807  884]
 [ 574 1523  366  766 1046   91  154  911  902  685]
 [ 944 1686  981 1391  849  280 1337 1263   91 1540]
 [ 879 1025  390  269 1115 1831  610   11  191 1686]
 [ 154   31 1237  289  769 1524  661  426 1008 1727]
 [ 182 1299 1430  511 1339 1010 1173 1457  664  529]]

倒排表搜索(Inverted file with exact post-verification)

# Inverted-file index: a coarse quantizer partitions the space into nlist cells,
# each vector is stored uncompressed in the list of its cell, and a search only
# scans the nprobe cells closest to the query.
quantizer = faiss.IndexFlatL2(d)  # coarse quantizer that holds the nlist centroids
nlist = 50  # number of cells (inverted lists)
index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_L2)
index.train(data)  # k-means on `data` to learn the nlist centroids
index.add(data)  # assign each vector to the list of its nearest centroid
index.nprobe = 30  # number of Voronoi cells scanned per query (more = more accurate, slower)
dis, ind = index.search(query, 10)
print(dis)
print(ind)
[[8.61838   8.782156  8.782816  8.832027  8.837635  8.8484955 8.897978
  8.9166355 8.919006  8.937399 ]
 [9.033302  9.038906  9.091706  9.164592  9.200113  9.201885  9.220333
  9.279479  9.312859  9.344341 ]
 [8.063819  8.211029  8.306456  8.373353  8.459253  8.459894  8.498556
  8.546466  8.555407  8.621424 ]
 [8.193895  8.211957  8.34701   8.446963  8.45299   8.45486   8.473572
  8.504771  8.513636  8.530685 ]
 [8.369623  8.549446  8.704066  8.736764  8.760081  8.777317  8.831345
  8.835485  8.858271  8.860057 ]
 [8.299071  8.432397  8.434382  8.457373  8.539217  8.562357  8.579033
  8.618738  8.630859  8.6983185]
 [8.615003  8.615164  8.72604   8.730944  8.762621  8.796932  8.797066
  8.797366  8.813984  8.834725 ]
 [8.377228  8.522776  8.711159  8.724562  8.745737  8.763845  8.7686
  8.7728    8.786858  8.828223 ]
 [8.3429165 8.488056  8.655106  8.662771  8.701336  8.741288  8.7436075
  8.770506  8.786265  8.8490505]
 [8.522163  8.575702  8.684618  8.767246  8.782908  8.850494  8.883732
  8.903692  8.909395  8.917681 ]]
 
 [[1269 1525 1723 1160 1694   48 1075 1028  544  916]
 [1035  259 1279 1398  879  289  882 1420 1927   13]
 [ 327  345 1401  389 1904 1992 1612  106  981 1179]
 [1259  112  351  804 1412 1987 1377  250 1624  133]
 [1666  854 1135  616   94  280   30   99 1212    3]
 [ 574 1523  366  766 1046   91  456  649   46  154]
 [1945  944  244  655 1686  981  256 1555 1280 1969]
 [ 879 1025  390  269 1115 1662 1831  610   11  191]
 [ 156  154   99   31 1237  289  769 1524   56  661]
 [ 427  182  375 1826  610 1384 1299  750    2 1430]]

LSH(Locality-Sensitive Hashing (binary flat index))

# Locality-sensitive hashing: each vector is hashed to an nbits-bit binary code,
# and candidates are ranked by Hamming distance between codes (hence the
# integer-valued distances in the output below).
nbits = 2 * d  # code length in bits (1024 for d = 512)
index = faiss.IndexLSH(d, nbits)
index.train(data)
index.add(data)
dis, ind = index.search(query, 10)
print(dis)
print(ind)
[[ 8. 10. 10. 10. 10. 10. 10. 11. 11. 11.]
 [ 7.  8.  9.  9.  9. 10. 10. 10. 10. 10.]
 [ 7.  8.  8.  9.  9.  9.  9.  9.  9.  9.]
 [ 9.  9. 10. 11. 12. 12. 12. 12. 12. 12.]
 [ 6.  6.  6.  7.  7.  8.  8.  8.  8.  8.]
 [ 8.  8.  8.  9.  9.  9.  9.  9. 10. 10.]
 [ 6.  7.  8.  8.  9.  9.  9.  9.  9.  9.]
 [ 9.  9.  9.  9.  9.  9.  9.  9.  9. 10.]
 [ 7.  8.  8.  8.  8.  8.  8.  9.  9.  9.]
 [ 9.  9.  9. 10. 10. 10. 10. 10. 10. 10.]]
 
 [[1424  345 1544  760 1589 1043  668  492  148  666]
 [1974 1436 1476   51  711  696   28  934  541 1125]
 [1667 1356 1149  512 1592 1544 1677  309 1021 1018]
 [ 708  107  606  243   18  612  598  615  269  250]
 [1455  541 1142  843 1140  888  165  961  797 1003]
 [1735  193  953 1071 1518 1109  449  263 1329 1216]
 [1129 1231 1731  123  860  907  381  993  336 1071]
 [1622  336 1970  845   70 1921 1973  980  331   42]
 [1335  713 1589  395  263 1206  346  698  913  678]
 [1395  279 1305  427 1707 1574 1710  226 1205 1160]]

SQ量化(Scalar quantizer (SQ) in flat mode)

# Scalar quantizer in flat mode: every component is encoded independently.
# QT_fp16 (the enum value 4) stores each component as a 16-bit float. It is NOT
# a 4-bit quantizer; use faiss.ScalarQuantizer.QT_4bit (or QT_8bit, ...) for a
# coarser encoding. The named constant replaces the bare integer 4.
index = faiss.IndexScalarQuantizer(d, faiss.ScalarQuantizer.QT_fp16)
index.train(data)  # kept so the snippet stays valid for the trainable quantizer types
index.add(data)
dis, ind = index.search(query, 10)  # approximate squared L2 distances from the encoded vectors
print(dis)
print(ind)
[[8.623228  8.777794  8.785317  8.828827  8.835491  8.845296  8.896896
  8.914822  8.922382  8.934984 ]
 [9.028503  9.037548  9.099254  9.152615  9.16542   9.196389  9.200497
  9.224977  9.274048  9.305386 ]
 [8.064029  8.213011  8.310526  8.376434  8.4578285 8.462004  8.500969
  8.550645  8.556992  8.624526 ]
 [8.196653  8.210537  8.3464365 8.444774  8.452023  8.454119  8.474525
  8.4966135 8.510042  8.525611 ]
 [8.370451  8.547961  8.704324  8.733618  8.763927  8.776734  8.82951
  8.835646  8.857151  8.859047 ]
 [8.29591   8.432422  8.435947  8.454732  8.542397  8.565366  8.579685
  8.621871  8.632036  8.64478  ]
 [8.609016  8.612934  8.726631  8.734137  8.758858  8.797329  8.797971
  8.798654  8.815296  8.838221 ]
 [8.37895   8.521536  8.710902  8.726156  8.748387  8.75966   8.768217
  8.769184  8.792372  8.834644 ]
 [8.340463  8.489507  8.659348  8.664953  8.702758  8.741514  8.741941
  8.768995  8.781276  8.852154 ]
 [8.520282  8.5749855 8.68346   8.769207  8.782043  8.851276  8.881114
  8.906744  8.907756  8.924014 ]]
 
 [[1269 1723 1525 1160 1694   48 1075  544 1028  916]
 [1035  259 1279 1116 1398  289  879  882 1420 1927]
 [ 327  345 1401  389 1992 1904 1612  106  981 1179]
 [1259  112  351  804 1987 1412 1377  250 1624  133]
 [1666  854 1135  616   94  280   30   99    3 1212]
 [ 574 1523  366  766 1046   91  456  649   46  896]
 [ 944 1945  244  655 1686  256  981 1555 1280 1969]
 [ 879 1025  390  269 1115 1662 1831  610   11  191]
 [ 156  154   99   31 1237  289  769 1524   56  661]
 [ 427  182  375 1826  610 1384 1299  750    2 1430]]

PQ量化(Product quantizer (PQ) in flat mode)

# Product quantizer in flat mode: each vector is split into M sub-vectors and
# each sub-vector is replaced by the id of its nearest centroid in that sub-space.
M = 8      # number of sub-quantizers; d must be divisible by M (512 / 8 = 64 dims each)
nbits = 6  # bits per sub-quantizer code, i.e. 2**nbits = 64 centroids per sub-space
           # (values other than 8/12/16 are accepted -- this run uses 6)
if d % M != 0:
    raise ValueError(f"d={d} must be divisible by M={M}")
index = faiss.IndexPQ(d, M, nbits)
index.train(data)  # k-means per sub-space; needs at least 2**nbits training vectors
index.add(data)    # stores only the M codes per vector (M * nbits bits)
dis, ind = index.search(query, 10)  # approximate squared L2 computed from the codes
print(dis)
print(ind)
[[5.3148193 5.33667   5.390381  5.3969727 5.4020996 5.402466  5.4088135
  5.420532  5.4210205 5.4370117]
 [5.694214  5.71875   5.7408447 5.7418213 5.743042  5.7543945 5.7611084
  5.764282  5.786255  5.791748 ]
 [4.880615  4.9449463 4.9451904 5.0146484 5.022583  5.0406494 5.0444336
  5.0650635 5.06604   5.0666504]
 [4.8305664 4.852661  4.8569336 4.87146   4.8901367 4.8969727 4.9003906
  4.9073486 4.9093018 4.911743 ]
 [5.2233887 5.3170166 5.3239746 5.338623  5.347534  5.356201  5.3599854
  5.362915  5.387085  5.40271  ]
 [5.024658  5.0458984 5.0463867 5.069214  5.1254883 5.1533203 5.1557617
  5.17395   5.1887207 5.188843 ]
 [5.070923  5.1273193 5.144409  5.1882324 5.1896973 5.194702  5.20813
  5.223633  5.223633  5.242798 ]
 [5.173218  5.2506104 5.26355   5.3077393 5.3099365 5.3240967 5.324585
  5.3448486 5.3449707 5.3479004]
 [5.1730957 5.246826  5.2974854 5.319092  5.3239746 5.3275146 5.3409424
  5.3464355 5.352051  5.364624 ]
 [5.204468  5.27124   5.272217  5.324463  5.335449  5.348755  5.3531494
  5.355713  5.359619  5.3670654]]
 
 [[1000 1775 1651  702 1963 1063  249 1995  689 1075]
 [ 243  532 1923   52  304 1212  449 1264 1092  622]
 [1904  981  735  492  458 1810 1945  839  875  616]
 [1307  148  250 1773  576  187  864  394 1920 1550]
 [1135 1429  151  773  250  502 1945 1408  694 1849]
 [1427 1463  816 1314 1096  896    8 1366 1673  939]
 [ 244  854   85  560 1154 1473  951 1626  218  885]
 [ 278 1176  787   70  235  326  190 1843  892 1756]
 [  81  855  416 1545  145 1811  172  383 1856   90]
 [ 902 1238  725 1141 1255  593 1507  596 1400 1434]]

倒排表乘积量化(IVFADC (coarse quantizer+PQ on residuals))

# IVFADC: a coarse quantizer (nlist cells) plus product quantization of each
# vector's residual with respect to its cell centroid.
M = 8       # number of PQ sub-quantizers; d must be divisible by M
nbits = 4   # bits per sub-quantizer code (2**4 = 16 centroids per sub-space)
nlist = 50  # number of coarse cells
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, M, nbits)
index.train(data)  # learns the coarse centroids and the residual PQ codebooks
index.add(data)
# NOTE(review): index.nprobe is never set, so faiss's IndexIVF default (1) applies
# and only one of the nlist cells is scanned per query. Together with 4-bit PQ
# codes, this likely explains why the neighbours below differ from the exact L2
# baseline. Set index.nprobe (the IVFFlat example uses 30) for comparable recall.
dis, ind = index.search(query, 10)
print(dis)
print(ind)
[[5.1813607 5.213015  5.227062  5.240297  5.3074746 5.3143606 5.3209453
  5.323071  5.3241897 5.3336196]
 [5.531564  5.5329485 5.562758  5.563456  5.5912466 5.6556826 5.672931
  5.68467   5.6997647 5.702039 ]
 [4.8162827 4.822724  4.8304024 4.8753777 4.885644  4.8879986 4.8881545
  4.893022  4.8931584 4.900308 ]
 [4.816623  4.83104   4.8495483 4.8536286 4.876833  4.877784  4.880099
  4.884579  4.8872194 4.891768 ]
 [5.0868177 5.118717  5.1225142 5.1229815 5.1336117 5.1365952 5.1425185
  5.1445856 5.1706796 5.1717625]
 [4.896156  4.923249  4.941252  4.9426517 4.951362  4.9712415 4.9738026
  4.9810586 4.9901094 4.991113 ]
 [4.98639   4.9903703 4.9991083 5.0107374 5.011694  5.0129166 5.016504
  5.020243  5.0209174 5.024062 ]
 [5.175429  5.175681  5.1776314 5.189157  5.1944485 5.222612  5.2232976
  5.226807  5.236521  5.2451673]
 [5.044052  5.080878  5.1016216 5.109776  5.1117053 5.124481  5.1256466
  5.1480865 5.151196  5.1520753]
 [5.129094  5.1587934 5.1708508 5.171063  5.18175   5.182567  5.1942434
  5.1995926 5.201419  5.2037187]]
  
 [[1962 1880  311 1897  666  201  647  283 1588  171]
 [ 569  148  162   39  753 1032  983  934  560 1715]
 [ 851  380 1322 1803 1678 1486 1504 1847 1206  306]
 [ 960 1741  636  510 1568  880  866 1134  615  381]
 [1799 1166  572 1631 1244  343 1212 1859  756 1630]
 [ 816 1753  749 1621  444 1399  658  771  369  913]
 [1902 1238  913 1546  654  969 1187  350  979 1251]
 [ 464  839 1705  772   77  490 1976  998  968   57]
 [1445  505  426  300 1220  122  806  755  100 1343]
 [1432 1263 1198 1537 1816  263    1  430  598  260]]

cell-probe方法

为了加速索引过程,经常采用划分子类空间(如k-means)的方法,虽然这样无法保证最后返回的结果是完全正确的。先划分子类空间,再在部分子空间中搜索的方法,就是cell-probe方法。
具体流程为:

  • 数据集空间被划分为n个部分,在k-means中,表现为n个类;
  • 每个类中的向量保存在一个倒排表中,共有n个倒排表;
  • 查询时,选中nprobe个倒排表;
  • 将这几个倒排表中的向量与查询向量作对比。

在这种方法中,只需要比较数据库中的一部分向量,大约占全部数据的nprobe/n;之所以是"大约",是因为各个倒排表的长度并不一致(每个类中的向量个数不一定相等)。

cell-probe粗量化

在一些索引类型(如IndexIVFFlat)中,需要一个Flat index作为粗量化器:训练时会将类中心保存在该Flat index中;在add和search阶段,会首先判定向量落入哪个类空间。在search阶段,需要调整nprobe参数以权衡检索精度与检索速度。
实验表明,对高维数据,需要维持比较高的nprobe数值才能保证精度

与LSH的优劣

LSH也是一种cell-probe方法,与基于k-means划分的cell-probe方法相比,LSH有以下几点不足:

  • LSH需要大量的哈希函数,会带来额外的内存开销
  • 哈希函数与数据无关(不依据输入数据的分布来构造),因而不能很好地适配输入数据

参考

Logo

开放原子开发者工作坊旨在鼓励更多人参与开源活动,与志同道合的开发者们相互交流开发经验、分享开发心得、获取前沿技术趋势。工作坊有多种形式的开发者活动,如meetup、训练营等,主打技术交流,干货满满,真诚地邀请各位开发者共同参与!

更多推荐