Skip to content

Commit

Permalink
introduce bf16 quantization
Browse files Browse the repository at this point in the history
- add bf16 quantization
- enable on hgraph & brute_force

Signed-off-by: LHT129 <[email protected]>
  • Loading branch information
LHT129 committed Feb 19, 2025
1 parent 7b1ebaa commit 5046cae
Show file tree
Hide file tree
Showing 16 changed files with 459 additions and 14 deletions.
5 changes: 3 additions & 2 deletions src/algorithm/hgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,8 @@ HGraph::EstimateMemory(uint64_t num_elements) const {
(static_cast<double>(this->bottom_graph_->maximum_degree_) / 2 + 1);
estimate_memory += static_cast<uint64_t>(sparse_graph_memory);

auto other_memory = element_count * (sizeof(LabelType) + sizeof(std::shared_mutex));
auto other_memory = element_count * (sizeof(LabelType) + sizeof(std::shared_mutex) +
sizeof(std::shared_ptr<std::shared_mutex>));
estimate_memory += other_memory;

return estimate_memory;
Expand Down Expand Up @@ -788,7 +789,7 @@ HGraph::init_features() {

// About Train
auto name = this->basic_flatten_codes_->GetQuantizerName();
if (name != QUANTIZATION_TYPE_VALUE_FP32) {
if (name != QUANTIZATION_TYPE_VALUE_FP32 and name != QUANTIZATION_TYPE_VALUE_BF16) {
feature_list_.SetFeature(IndexFeature::NEED_TRAIN);
} else {
feature_list_.SetFeature(IndexFeature::SUPPORT_CAL_DISTANCE_BY_ID);
Expand Down
3 changes: 3 additions & 0 deletions src/data_cell/flatten_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ make_instance(const FlattenDataCellParamPtr& param, const IndexCommonParam& comm
if (quantization_string == QUANTIZATION_TYPE_VALUE_SQ8_UNIFORM) {
return make_instance<SQ8UniformQuantizer<metric>, IOTemp>(param, common_param);
}
if (quantization_string == QUANTIZATION_TYPE_VALUE_BF16) {
return make_instance<BF16Quantizer<metric>, IOTemp>(param, common_param);
}
return nullptr;
}

Expand Down
2 changes: 1 addition & 1 deletion src/index/brute_force.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ void
BruteForce::init_feature_list() {
// About Train
auto name = this->inner_codes_->GetQuantizerName();
if (name != QUANTIZATION_TYPE_VALUE_FP32) {
if (name != QUANTIZATION_TYPE_VALUE_FP32 and name != QUANTIZATION_TYPE_VALUE_BF16) {
feature_list_.SetFeature(IndexFeature::NEED_TRAIN);
} else {
feature_list_.SetFeatures({
Expand Down
1 change: 1 addition & 0 deletions src/inner_string_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const char* const QUANTIZATION_TYPE_VALUE_SQ8_UNIFORM = "sq8_uniform";
const char* const QUANTIZATION_TYPE_VALUE_SQ4 = "sq4";
const char* const QUANTIZATION_TYPE_VALUE_SQ4_UNIFORM = "sq4_uniform";
const char* const QUANTIZATION_TYPE_VALUE_FP32 = "fp32";
const char* const QUANTIZATION_TYPE_VALUE_BF16 = "bf16";
const char* const QUANTIZATION_TYPE_VALUE_PQ = "pq";

// graph param value
Expand Down
1 change: 1 addition & 0 deletions src/quantization/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ set (QUANTIZER_SRC
scalar_quantization/sq8_uniform_quantizer_parameter.cpp
scalar_quantization/sq4_quantizer_parameter.cpp
scalar_quantization/sq4_uniform_quantizer_parameter.cpp
scalar_quantization/bf16_quantizer_parameter.cpp
scalar_quantization/scalar_quantization_trainer.cpp
)

Expand Down
3 changes: 3 additions & 0 deletions src/quantization/quantizer_parameter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ QuantizerParameter::GetQuantizerParameterByJson(const JsonType& json) {
} else if (type_name == QUANTIZATION_TYPE_VALUE_SQ4_UNIFORM) {
quantizer_param = std::make_shared<SQ4UniformQuantizerParameter>();
quantizer_param->FromJson(json);
} else if (type_name == QUANTIZATION_TYPE_VALUE_BF16) {
quantizer_param = std::make_shared<BF16QuantizerParameter>();
quantizer_param->FromJson(json);
} else {
throw std::invalid_argument(fmt::format("invalid quantizer name {}", type_name));
}
Expand Down
12 changes: 6 additions & 6 deletions src/quantization/quantizer_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,37 +171,37 @@ TestComputer(
need_normalize = false;
}
auto vecs = fixtures::generate_vectors(count, dim, need_normalize);
auto querys = fixtures::generate_vectors(query_count, dim, need_normalize, 165);
auto queries = fixtures::generate_vectors(query_count, dim, need_normalize, 165);
if (retrain) {
quant.ReTrain(vecs.data(), count);
}

auto gt_func = [&](int base_idx, int query_idx) -> float {
if constexpr (metric == vsag::MetricType::METRIC_TYPE_IP) {
return 1 - InnerProduct(
vecs.data() + base_idx * dim, querys.data() + query_idx * dim, &dim);
vecs.data() + base_idx * dim, queries.data() + query_idx * dim, &dim);
} else if constexpr (metric == vsag::MetricType::METRIC_TYPE_L2SQR) {
return L2Sqr(vecs.data() + base_idx * dim, querys.data() + query_idx * dim, &dim);
return L2Sqr(vecs.data() + base_idx * dim, queries.data() + query_idx * dim, &dim);
} else if constexpr (metric == vsag::MetricType::METRIC_TYPE_COSINE) {
std::vector<float> v1(dim);
std::vector<float> v2(dim);
Normalize(vecs.data() + base_idx * dim, v1.data(), dim);
Normalize(querys.data() + query_idx * dim, v2.data(), dim);
Normalize(queries.data() + query_idx * dim, v2.data(), dim);
return 1 - InnerProduct(v1.data(), v2.data(), &dim);
}
};

for (int i = 0; i < query_count; ++i) {
std::shared_ptr<Computer<T>> computer;
computer = quant.FactoryComputer();
computer->SetQuery(querys.data() + i * dim);
computer->SetQuery(queries.data() + i * dim);

// Test Compute One Dist;
for (int j = 0; j < 100; ++j) {
auto idx1 = random() % count;
auto* codes1 = new uint8_t[quant.GetCodeSize()];
quant.EncodeOne(vecs.data() + idx1 * dim, codes1);
float value = 0.0f;
float value = 0.0F;
quant.ComputeDist(*computer, codes1, &value);
REQUIRE(quant.ComputeDist(*computer, codes1) == value);
auto gt = gt_func(idx1, i);
Expand Down
225 changes: 225 additions & 0 deletions src/quantization/scalar_quantization/bf16_quantizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstring>
#include <limits>
#include <vector>

#include "bf16_quantizer_parameter.h"
#include "byte_buffer.h"
#include "index/index_common_param.h"
#include "inner_string_params.h"
#include "quantization/quantizer.h"
#include "simd/bf16_simd.h"
#include "simd/normalize.h"
#include "typing.h"

namespace vsag {

template <MetricType metric = MetricType::METRIC_TYPE_L2SQR>
class BF16Quantizer : public Quantizer<BF16Quantizer<metric>> {
public:
explicit BF16Quantizer(int dim, Allocator* allocator);

explicit BF16Quantizer(const BF16QuantizerParamPtr& param,
const IndexCommonParam& common_param);

explicit BF16Quantizer(const QuantizerParamPtr& param, const IndexCommonParam& common_param);

bool
TrainImpl(const DataType* data, uint64_t count);

bool
EncodeOneImpl(const DataType* data, uint8_t* codes) const;

bool
EncodeBatchImpl(const DataType* data, uint8_t* codes, uint64_t count);

bool
DecodeOneImpl(const uint8_t* codes, DataType* data);

bool
DecodeBatchImpl(const uint8_t* codes, DataType* data, uint64_t count);

inline float
ComputeImpl(const uint8_t* codes1, const uint8_t* codes2);

inline void
ProcessQueryImpl(const DataType* query, Computer<BF16Quantizer>& computer) const;

inline void
ComputeDistImpl(Computer<BF16Quantizer>& computer, const uint8_t* codes, float* dists) const;

inline void
ComputeBatchDistImpl(Computer<BF16Quantizer<metric>>& computer,
uint64_t count,
const uint8_t* codes,
float* dists) const;

inline void
ReleaseComputerImpl(Computer<BF16Quantizer<metric>>& computer) const;

inline void
SerializeImpl(StreamWriter& writer){};

inline void
DeserializeImpl(StreamReader& reader){};

[[nodiscard]] std::string
NameImpl() const {
return QUANTIZATION_TYPE_VALUE_BF16;
}
};

template <MetricType metric>
BF16Quantizer<metric>::BF16Quantizer(int dim, Allocator* allocator)
: Quantizer<BF16Quantizer<metric>>(dim, allocator) {
this->code_size_ = dim * 2;
}

template <MetricType metric>
BF16Quantizer<metric>::BF16Quantizer(const BF16QuantizerParamPtr& param,
const IndexCommonParam& common_param)
: BF16Quantizer<metric>(common_param.dim_, common_param.allocator_.get()){};

template <MetricType metric>
BF16Quantizer<metric>::BF16Quantizer(const QuantizerParamPtr& param,
const IndexCommonParam& common_param)
: BF16Quantizer<metric>(std::dynamic_pointer_cast<BF16QuantizerParameter>(param),
common_param){};

template <MetricType metric>
bool
BF16Quantizer<metric>::TrainImpl(const DataType* data, uint64_t count) {
if (data == nullptr) {
return false;

Check warning on line 109 in src/quantization/scalar_quantization/bf16_quantizer.h

View check run for this annotation

Codecov / codecov/patch

src/quantization/scalar_quantization/bf16_quantizer.h#L109

Added line #L109 was not covered by tests
}
return true;
}

template <MetricType metric>
bool
BF16Quantizer<metric>::EncodeOneImpl(const DataType* data, uint8_t* codes) const {
const DataType* cur = data;
Vector<float> tmp(this->allocator_);
if constexpr (metric == MetricType::METRIC_TYPE_COSINE) {
tmp.resize(this->dim_);
Normalize(data, tmp.data(), this->dim_);
cur = tmp.data();
}
auto* codes_bf16 = reinterpret_cast<uint16_t*>(codes);
for (int i = 0; i < this->dim_; ++i) {
codes_bf16[i] = generic::FloatToBF16(cur[i]);
}

return true;
}

template <MetricType metric>
bool
BF16Quantizer<metric>::EncodeBatchImpl(const DataType* data, uint8_t* codes, uint64_t count) {
for (uint64_t i = 0; i < count; ++i) {
this->EncodeOneImpl(data + i * this->dim_, codes + i * this->code_size_);
}
return true;
}

template <MetricType metric>
bool
BF16Quantizer<metric>::DecodeOneImpl(const uint8_t* codes, DataType* data) {
const auto* codes_bf16 = reinterpret_cast<const uint16_t*>(codes);

for (uint64_t d = 0; d < this->dim_; d++) {
data[d] = generic::BF16ToFloat(codes_bf16[d]);
}
return true;
}

template <MetricType metric>
bool
BF16Quantizer<metric>::DecodeBatchImpl(const uint8_t* codes, DataType* data, uint64_t count) {
for (uint64_t i = 0; i < count; ++i) {
this->DecodeOneImpl(codes + i * this->code_size_, data + i * this->dim_);
}
return true;
}

template <MetricType metric>
inline float
BF16Quantizer<metric>::ComputeImpl(const uint8_t* codes1, const uint8_t* codes2) {
if constexpr (metric == MetricType::METRIC_TYPE_L2SQR) {
return BF16ComputeL2Sqr(codes1, codes2, this->dim_);
} else if constexpr (metric == MetricType::METRIC_TYPE_IP) {
return 1 - BF16ComputeIP(codes1, codes2, this->dim_);
} else if constexpr (metric == MetricType::METRIC_TYPE_COSINE) {
return 1 - BF16ComputeIP(codes1, codes2, this->dim_);
} else {
return 0;
}
}

template <MetricType metric>
void
BF16Quantizer<metric>::ProcessQueryImpl(const DataType* query,
Computer<BF16Quantizer>& computer) const {
try {
computer.buf_ = reinterpret_cast<uint8_t*>(this->allocator_->Allocate(this->code_size_));
this->EncodeOneImpl(query, computer.buf_);
} catch (const std::bad_alloc& e) {
computer.buf_ = nullptr;
logger::error("bad alloc when init computer buf");
throw std::bad_alloc();
}
}

template <MetricType metric>
void
BF16Quantizer<metric>::ComputeDistImpl(Computer<BF16Quantizer>& computer,
const uint8_t* codes,
float* dists) const {
auto* buf = computer.buf_;
if constexpr (metric == MetricType::METRIC_TYPE_L2SQR) {
dists[0] = BF16ComputeL2Sqr(buf, codes, this->dim_);
} else if constexpr (metric == MetricType::METRIC_TYPE_IP) {
dists[0] = 1 - BF16ComputeIP(buf, codes, this->dim_);
} else if constexpr (metric == MetricType::METRIC_TYPE_COSINE) {
dists[0] = 1 - BF16ComputeIP(buf, codes, this->dim_);
} else {
logger::error("unsupported metric type");
dists[0] = 0;
}
}

template <MetricType metric>
void
BF16Quantizer<metric>::ComputeBatchDistImpl(Computer<BF16Quantizer<metric>>& computer,
uint64_t count,
const uint8_t* codes,
float* dists) const {
// TODO(LHT): Optimize batch for simd
for (uint64_t i = 0; i < count; ++i) {
this->ComputeDistImpl(computer, codes + i * this->code_size_, dists + i);
}
}

template <MetricType metric>
void
BF16Quantizer<metric>::ReleaseComputerImpl(Computer<BF16Quantizer<metric>>& computer) const {
this->allocator_->Deallocate(computer.buf_);
}

} // namespace vsag
36 changes: 36 additions & 0 deletions src/quantization/scalar_quantization/bf16_quantizer_parameter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@

// Copyright 2024-present the vsag project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "bf16_quantizer_parameter.h"

#include "inner_string_params.h"

namespace vsag {

// Parameter object for the bf16 quantizer. bf16 quantization is stateless and
// has no tunable options, so the JSON representation only carries the
// quantizer type tag.
BF16QuantizerParameter::BF16QuantizerParameter()
    : QuantizerParameter(QUANTIZATION_TYPE_VALUE_BF16) {
}

void
BF16QuantizerParameter::FromJson(const JsonType& /*json*/) {
    // No configurable fields to parse for bf16; parameter name is omitted to
    // avoid an unused-parameter warning.
}

JsonType
BF16QuantizerParameter::ToJson() {
    JsonType json;
    json[QUANTIZATION_TYPE_KEY] = QUANTIZATION_TYPE_VALUE_BF16;
    return json;
}
}  // namespace vsag
Loading

0 comments on commit 5046cae

Please sign in to comment.