Skip to content
Snippets Groups Projects
Commit e3fa3561 authored by David Sehnal's avatar David Sehnal
Browse files

BinaryCIF field classifier, auto classify fields & field encoding provider

parent 308a6b57
No related branches found
No related tags found
No related merge requests found
......@@ -7,172 +7,7 @@
import { Column } from 'mol-data/db'
import { CifField } from 'mol-io/reader/cif/data-model'
import { CifWriter } from 'mol-io/writer/cif'
import { ArrayEncoder, ArrayEncoding as E } from 'mol-io/common/binary-cif';
namespace IntClassifier {
function packSize(value: number, upperLimit: number) {
return value >= 0
? Math.ceil((value + 1) / upperLimit)
: Math.ceil((value + 1) / (-upperLimit - 1));
}
type IntColumnInfo = { signed: boolean, limit8: number, limit16: number };
function getInfo(data: number[]): IntColumnInfo {
let signed = false;
for (let i = 0, n = data.length; i < n; i++) {
if (data[i] < 0) {
signed = true;
break;
}
}
return signed ? { signed, limit8: 0x7F, limit16: 0x7FFF } : { signed, limit8: 0xFF, limit16: 0xFFFF };
}
type SizeInfo = { pack8: number, pack16: number, count: number }
function SizeInfo(): SizeInfo { return { pack8: 0, pack16: 0, count: 0 } };
function incSize({ limit8, limit16 }: IntColumnInfo, info: SizeInfo, value: number) {
info.pack8 += packSize(value, limit8);
info.pack16 += packSize(value, limit16);
info.count += 1;
}
function incSizeSigned(info: SizeInfo, value: number) {
info.pack8 += packSize(value, 0x7F);
info.pack16 += packSize(value, 0x7FFF);
info.count += 1;
}
function byteSize(info: SizeInfo) {
if (info.count * 4 < info.pack16 * 2) return { length: info.count * 4, elem: 4 };
if (info.pack16 * 2 < info.pack8) return { length: info.pack16 * 2, elem: 2 };
return { length: info.pack8, elem: 1 };
}
function packingSize(data: number[], info: IntColumnInfo) {
const size = SizeInfo();
for (let i = 0, n = data.length; i < n; i++) {
incSize(info, size, data[i]);
}
return { ...byteSize(size), kind: 'pack' };
}
function deltaSize(data: number[], info: IntColumnInfo) {
const size = SizeInfo();
let prev = data[0];
for (let i = 1, n = data.length; i < n; i++) {
incSizeSigned(size, data[i] - prev);
prev = data[i];
}
return { ...byteSize(size), kind: 'delta' };
}
function rleSize(data: number[], info: IntColumnInfo) {
const size = SizeInfo();
let run = 1;
for (let i = 1, n = data.length; i < n; i++) {
if (data[i - 1] !== data[i]) {
incSize(info, size, data[i - 1]);
incSize(info, size, run);
run = 1;
} else {
run++;
}
}
incSize(info, size, data[data.length - 1]);
incSize(info, size, run);
return { ...byteSize(size), kind: 'rle' };
}
function deltaRleSize(data: number[], info: IntColumnInfo) {
const size = SizeInfo();
let run = 1, prev = 0, prevValue = 0;
for (let i = 1, n = data.length; i < n; i++) {
const v = data[i] - prev;
if (prevValue !== v) {
incSizeSigned(size, prevValue);
incSizeSigned(size, run);
run = 1;
} else {
run++;
}
prevValue = v;
prev = data[i];
}
incSizeSigned(size, prevValue);
incSizeSigned(size, run);
return { ...byteSize(size), kind: 'delta-rle' };
}
export function getSize(data: number[]) {
const info = getInfo(data);
const sizes = [packingSize(data, info), rleSize(data, info), deltaSize(data, info), deltaRleSize(data, info)];
sizes.sort((a, b) => a.length - b.length);
return sizes;
}
export function classify(data: number[], name: string): ArrayEncoder {
if (data.length < 2) return E.by(E.byteArray);
const sizes = getSize(data);
const size = sizes[0];
// console.log(`${name}: ${size.kind} ${size.length}b ${data.length}`);
// console.log(`${name}: ${sizes.map(s => `${s.kind}: ${s.length}b`).join(' | ')}`);
switch (size.kind) {
case 'pack': return E.by(E.integerPacking);
case 'rle': return E.by(E.runLength).and(E.integerPacking);
case 'delta': return E.by(E.delta).and(E.integerPacking);
case 'delta-rle': return E.by(E.delta).and(E.runLength).and(E.integerPacking);
}
throw 'bug';
}
}
namespace FloatClassifier {
const delta = 1e-6;
function digitCount(v: number) {
let m = 1;
for (let i = 0; i < 5; i++) {
const r = Math.round(m * v) / m;
if (Math.abs(v - r) < delta) return m;
m *= 10;
}
return 10000;
}
export function classify(data: number[], name: string) {
// if a vector/matrix, do not reduce precision
if (name.indexOf('[') > 0) return { encoder: E.by(E.byteArray), typedArray: Float64Array };
let dc = 10;
for (let i = 0, n = data.length; i < n; i++) dc = Math.max(dc, digitCount(data[i]));
if (dc >= 10000) return { encoder: E.by(E.byteArray), typedArray: Float64Array };
const intArray = new Int32Array(data.length);
for (let i = 0, n = data.length; i < n; i++) intArray[i] = data[i] * dc;
const sizes = IntClassifier.getSize(intArray as any);
const size = sizes[0];
// console.log(`>> ${name}: ${size.kind} ${size.length}b ${data.length} x${dc}`);
// console.log(` ${name}: ${sizes.map(s => `${s.kind}: ${s.length}b`).join(' | ')}`);
switch (size.kind) {
case 'pack': return { encoder: E.by(E.fixedPoint(dc)).and(E.integerPacking), typedArray: Float32Array };
case 'rle': return { encoder: E.by(E.fixedPoint(dc)).and(E.runLength).and(E.integerPacking), typedArray: Float32Array };
case 'delta': return { encoder: E.by(E.fixedPoint(dc)).and(E.delta).and(E.integerPacking), typedArray: Float32Array };
case 'delta-rle': return { encoder: E.by(E.fixedPoint(dc)).and(E.delta).and(E.runLength).and(E.integerPacking), typedArray: Float32Array };
}
throw 'bug';
}
}
import { classifyFloatArray, classifyIntArray } from 'mol-io/common/binary-cif/classifier';
const intRegex = /^-?\d+$/
const floatRegex = /^-?(([0-9]+)[.]?|([0-9]*[.][0-9]+))([(][0-9]+[)])?([eE][+-]?[0-9]+)?$/
......@@ -192,10 +27,10 @@ function classify(name: string, field: CifField): CifWriter.Field {
if (hasString) return { name, type: CifWriter.Field.Type.Str, value: field.str, valueKind: field.valueKind };
if (floatCount > 0) {
const { encoder, typedArray } = FloatClassifier.classify(field.toFloatArray({ array: Float64Array }) as number[], name)
return CifWriter.Field.float(name, field.float, { valueKind: field.valueKind, encoder, typedArray });
const encoder = classifyFloatArray(field.toFloatArray({ array: Float64Array }));
return CifWriter.Field.float(name, field.float, { valueKind: field.valueKind, encoder, typedArray: Float64Array });
} else {
const encoder = IntClassifier.classify(field.toIntArray({ array: Int32Array }) as number[], name);
const encoder = classifyIntArray(field.toIntArray({ array: Int32Array }));
return CifWriter.Field.int(name, field.int, { valueKind: field.valueKind, encoder, typedArray: Int32Array });
}
}
......
......@@ -53,6 +53,24 @@ export namespace ArrayEncoder {
export function by(f: ArrayEncoding.Provider): ArrayEncoder {
return new ArrayEncoderImpl([f]);
}
export function fromEncoding(encoding: Encoding[]) {
const e = by(getProvider(encoding[0]));
for (let i = 1; i < encoding.length; i++) e.and(getProvider(encoding[i]));
return e;
}
function getProvider(e: Encoding): ArrayEncoding.Provider {
switch (e.kind) {
case 'ByteArray': return ArrayEncoding.byteArray;
case 'FixedPoint': return ArrayEncoding.fixedPoint(e.factor);
case 'IntervalQuantization': return ArrayEncoding.intervalQuantizaiton(e.min, e.max, e.numSteps);
case 'RunLength': return ArrayEncoding.runLength;
case 'Delta': return ArrayEncoding.delta;
case 'IntegerPacking': return ArrayEncoding.integerPacking;
case 'StringArray': return ArrayEncoding.stringArray;
}
}
}
export namespace ArrayEncoding {
......
......@@ -6,7 +6,7 @@
*/
import { ArrayEncoder, ArrayEncoding as E } from './array-encoder';
import { getArrayMantissaMultiplier } from 'mol-util/number';
import { getArrayDigitCount } from 'mol-util/number';
export function classifyIntArray(xs: ArrayLike<number>) {
return IntClassifier.classify(xs as number[]);
......@@ -141,33 +141,42 @@ namespace IntClassifier {
namespace FloatClassifier {
const delta = 1e-6;
export function classify(data: number[]) {
const digitCount = getArrayMantissaMultiplier(data, 4, delta);
if (digitCount < 0) return { encoder: E.by(E.byteArray), typedArray: Float64Array };
const maxDigits = 4;
// TODO: check for overflows here?
if (digitCount === 1) return { encoder: IntClassifier.classify(data), typedArray: Int32Array }
const { mantissaDigits, integerDigits } = getArrayDigitCount(data, maxDigits, delta);
// TODO: better check for overflows here?
if (mantissaDigits < 0 || mantissaDigits + integerDigits > 10) return E.by(E.byteArray);
// TODO: this needs a conversion to Int?Array?
if (mantissaDigits === 0) return IntClassifier.classify(data);
const multiplier = getMultiplier(mantissaDigits);
const intArray = new Int32Array(data.length);
for (let i = 0, n = data.length; i < n; i++) {
const v = digitCount * data[i];
intArray[i] = v;
// check if the value didn't overflow
if (Math.abs(Math.round(v) / digitCount - intArray[i] / digitCount) > delta) {
return { encoder: E.by(E.byteArray), typedArray: Float64Array };
}
intArray[i] = Math.round(multiplier * data[i]);
// TODO: enable this again?
// const v = Math.round(multiplier * data[i]);
// if (Math.abs(Math.round(v) / multiplier - intArray[i] / multiplier) > delta) {
// return E.by(E.byteArray);
// }
}
const sizes = IntClassifier.getSize(intArray as any);
const size = sizes[0];
const fp = E.by(E.fixedPoint(digitCount));
const fp = E.by(E.fixedPoint(multiplier));
switch (size.kind) {
case 'pack': return { encoder: fp.and(E.integerPacking), typedArray: Float32Array };
case 'rle': return { encoder: fp.and(E.runLength).and(E.integerPacking), typedArray: Float32Array };
case 'delta': return { encoder: fp.and(E.delta).and(E.integerPacking), typedArray: Float32Array };
case 'delta-rle': return { encoder: fp.and(E.delta).and(E.runLength).and(E.integerPacking), typedArray: Float32Array };
case 'pack': return fp.and(E.integerPacking);
case 'rle': return fp.and(E.runLength).and(E.integerPacking);
case 'delta': return fp.and(E.delta).and(E.integerPacking);
case 'delta-rle': return fp.and(E.delta).and(E.runLength).and(E.integerPacking);
}
throw new Error('should not happen :)');
}
function getMultiplier(mantissaDigits: number) {
let m = 1;
for (let i = 0; i < mantissaDigits; i++) m *= 10;
return m;
}
}
\ No newline at end of file
......@@ -39,6 +39,7 @@ export default function Field(column: EncodedColumn): Data.CifField {
return {
__array: data,
binaryEncoding: column.data.encoding,
isDefined: true,
rowCount,
str,
......
......@@ -8,6 +8,7 @@
import { Column } from 'mol-data/db'
import { Tensor } from 'mol-math/linear-algebra'
import { getNumberType, NumberType } from '../common/text/number-parser';
import { Encoding } from '../../common/binary-cif';
export interface CifFile {
readonly name?: string,
......@@ -62,7 +63,8 @@ export namespace CifCategory {
* This is to ensure that the functions can invoked without having to "bind" them.
*/
export interface CifField {
readonly __array: ArrayLike<any> | undefined
readonly __array: ArrayLike<any> | undefined,
readonly binaryEncoding: Encoding[] | undefined,
readonly isDefined: boolean,
readonly rowCount: number,
......
......@@ -39,6 +39,7 @@ export default function CifTextField(tokens: Tokens, rowCount: number): Data.Cif
return {
__array: void 0,
binaryEncoding: void 0,
isDefined: true,
rowCount,
str,
......
......@@ -6,9 +6,10 @@
*/
import TextEncoder from './cif/encoder/text'
import BinaryEncoder from './cif/encoder/binary'
import BinaryEncoder, { EncodingProvider } from './cif/encoder/binary'
import * as _Encoder from './cif/encoder'
import { ArrayEncoding } from '../common/binary-cif';
import { ArrayEncoding, ArrayEncoder } from '../common/binary-cif';
import { CifFrame } from '../reader/cif';
export namespace CifWriter {
export import Encoder = _Encoder.Encoder
......@@ -16,9 +17,16 @@ export namespace CifWriter {
export import Field = _Encoder.Field
export import Encoding = ArrayEncoding
export function createEncoder(params?: { binary?: boolean, encoderName?: string }): Encoder {
export interface EncoderParams {
binary?: boolean,
encoderName?: string,
binaryEncodingPovider?: EncodingProvider,
binaryAutoClassifyEncoding?: boolean
}
export function createEncoder(params?: EncoderParams): Encoder {
const { binary = false, encoderName = 'mol*' } = params || {};
return binary ? new BinaryEncoder(encoderName) : new TextEncoder();
return binary ? new BinaryEncoder(encoderName, params ? params.binaryEncodingPovider : void 0, params ? !!params.binaryAutoClassifyEncoding : false) : new TextEncoder();
}
export function fields<K = number, D = any>() {
......@@ -31,4 +39,15 @@ export namespace CifWriter {
fixedPoint2: E.by(E.fixedPoint(100)).and(E.delta).and(E.integerPacking),
fixedPoint3: E.by(E.fixedPoint(1000)).and(E.delta).and(E.integerPacking),
};
export function createEncodingProviderFromCifFrame(frame: CifFrame): EncodingProvider {
return {
get(c, f) {
const cat = frame.categories[c];
if (!cat) return void 0;
const ff = cat.getField(f);
return ff && ff.binaryEncoding ? ArrayEncoder.fromEncoding(ff.binaryEncoding) : void 0;
}
}
}
}
\ No newline at end of file
......@@ -10,11 +10,16 @@ import { Iterator } from 'mol-data'
import { Column } from 'mol-data/db'
import encodeMsgPack from '../../../common/msgpack/encode'
import {
EncodedColumn, EncodedData, EncodedFile, EncodedDataBlock, EncodedCategory, ArrayEncoder, ArrayEncoding as E, VERSION, ArrayEncoding
EncodedColumn, EncodedData, EncodedFile, EncodedDataBlock, EncodedCategory, ArrayEncoder, ArrayEncoding as E, VERSION
} from '../../../common/binary-cif'
import { Field, Category, Encoder } from '../encoder'
import Writer from '../../writer'
import { getIncludedFields } from './util';
import { classifyIntArray, classifyFloatArray } from '../../../common/binary-cif/classifier';
export interface EncodingProvider {
get(category: string, field: string): ArrayEncoder | undefined;
}
export default class BinaryEncoder implements Encoder<Uint8Array> {
private data: EncodedFile;
......@@ -64,7 +69,7 @@ export default class BinaryEncoder implements Encoder<Uint8Array> {
if (!this.filter.includeField(category.name, f.name)) continue;
const format = this.formatter.getFormat(category.name, f.name);
cat.columns.push(encodeField(f, data, count, getArrayCtor(f, format), getEncoder(f, format)));
cat.columns.push(encodeField(category.name, f, data, count, format, this.encodingProvider, this.autoClassify));
}
// no columns included.
if (!cat.columns.length) return;
......@@ -88,7 +93,7 @@ export default class BinaryEncoder implements Encoder<Uint8Array> {
return this.encodedData;
}
constructor(encoder: string) {
constructor(encoder: string, private encodingProvider: EncodingProvider | undefined, private autoClassify: boolean) {
this.data = {
encoder,
version: VERSION,
......@@ -97,37 +102,74 @@ export default class BinaryEncoder implements Encoder<Uint8Array> {
}
}
function createArray(type: Field.Type, arrayCtor: ArrayEncoding.TypedArrayCtor | undefined, count: number) {
if (type === Field.Type.Str) return new Array(count) as any;
else if (arrayCtor) return new arrayCtor(count) as any;
else return (type === Field.Type.Int ? new Int32Array(count) : new Float32Array(count)) as any;
}
function getArrayCtor(field: Field, format: Field.Format | undefined) {
function getArrayCtor(field: Field, format: Field.Format | undefined): Helpers.ArrayCtor<string | number> {
if (format && format.typedArray) return format.typedArray;
if (field.defaultFormat && field.defaultFormat.typedArray) return field.defaultFormat.typedArray;
return void 0;
if (field.type === Field.Type.Str) return Array;
if (field.type === Field.Type.Int) return Int32Array;
return Float64Array;
}
function getEncoder(field: Field, format: Field.Format | undefined) {
if (format && format.encoder) return format.encoder;
if (field.defaultFormat && field.defaultFormat.encoder) {
function getDefaultEncoder(type: Field.Type): ArrayEncoder {
if (type === Field.Type.Str) return ArrayEncoder.by(E.stringArray);
return ArrayEncoder.by(E.byteArray);
}
function tryGetEncoder(categoryName: string, field: Field, format: Field.Format | undefined, provider: EncodingProvider | undefined) {
if (format && format.encoder) {
return format.encoder;
} else if (field.defaultFormat && field.defaultFormat.encoder) {
return field.defaultFormat.encoder;
} else if (field.type === Field.Type.Str) {
return ArrayEncoder.by(E.stringArray);
} else if (provider) {
return provider.get(categoryName, field.name);
} else {
return ArrayEncoder.by(E.byteArray);
return void 0;
}
}
function encodeField(field: Field, data: { data: any, keys: () => Iterator<any> }[], totalCount: number, arrayCtor: ArrayEncoding.TypedArrayCtor | undefined, encoder: ArrayEncoder): EncodedColumn {
function classify(type: Field.Type, data: ArrayLike<any>) {
if (type === Field.Type.Str) return ArrayEncoder.by(E.stringArray);
if (type === Field.Type.Int) return classifyIntArray(data);
return classifyFloatArray(data);
}
function encodeField(categoryName: string, field: Field, data: { data: any, keys: () => Iterator<any> }[], totalCount: number,
format: Field.Format | undefined, encoderProvider: EncodingProvider | undefined, autoClassify: boolean): EncodedColumn {
const { array, allPresent, mask } = getFieldData(field, getArrayCtor(field, format), totalCount, data);
let encoder: ArrayEncoder | undefined = tryGetEncoder(categoryName, field, format, encoderProvider);
if (!encoder) {
if (autoClassify) encoder = classify(field.type, array);
else encoder = getDefaultEncoder(field.type);
}
const encoded = encoder.encode(array);
let maskData: EncodedData | undefined = void 0;
if (!allPresent) {
const maskRLE = ArrayEncoder.by(E.runLength).and(E.byteArray).encode(mask);
if (maskRLE.data.length < mask.length) {
maskData = maskRLE;
} else {
maskData = ArrayEncoder.by(E.byteArray).encode(mask);
}
}
return {
name: field.name,
data: encoded,
mask: maskData
};
}
function getFieldData(field: Field<any, any>, arrayCtor: Helpers.ArrayCtor<string | number>, totalCount: number, data: { data: any; keys: () => Iterator<any>; }[]) {
const isStr = field.type === Field.Type.Str;
const array = createArray(field.type, arrayCtor, totalCount);
const array = new arrayCtor(totalCount);
const mask = new Uint8Array(totalCount);
const valueKind = field.valueKind;
const getter = field.value;
let allPresent = true;
let offset = 0;
for (let _d = 0; _d < data.length; _d++) {
const d = data[_d].data;
......@@ -137,32 +179,16 @@ function encodeField(field: Field, data: { data: any, keys: () => Iterator<any>
const p = valueKind ? valueKind(key, d) : Column.ValueKind.Present;
if (p !== Column.ValueKind.Present) {
mask[offset] = p;
if (isStr) array[offset] = '';
if (isStr)
array[offset] = '';
allPresent = false;
} else {
}
else {
mask[offset] = Column.ValueKind.Present;
array[offset] = getter(key, d, offset);
}
offset++;
}
}
const encoded = encoder.encode(array);
let maskData: EncodedData | undefined = void 0;
if (!allPresent) {
const maskRLE = ArrayEncoder.by(E.runLength).and(E.byteArray).encode(mask);
if (maskRLE.data.length < mask.length) {
maskData = maskRLE;
} else {
maskData = ArrayEncoder.by(E.byteArray).encode(mask);
}
}
return {
name: field.name,
data: encoded,
mask: maskData
};
}
\ No newline at end of file
return { array, allPresent, mask };
}
......@@ -8,26 +8,40 @@
* If no such M exists, return -1.
*/
export function getMantissaMultiplier(v: number, maxDigits: number, delta: number) {
let m = 1;
for (let i = 0; i < maxDigits; i++) {
let m = 1, i;
for (i = 0; i < maxDigits; i++) {
let mv = m * v;
if (Math.abs(Math.round(mv) - mv) <= delta) return m;
if (Math.abs(Math.round(mv) - mv) <= delta) return i;
m *= 10;
}
return -1;
}
export function integerDigitCount(v: number, delta: number) {
const f = Math.abs(v);
if (f < delta) return 0;
return Math.floor(Math.log10(Math.abs(v))) + 1;
}
/**
* Determine the maximum number of digits in a floating point array.
* Find a number M such that round(M * v) - M * v <= delta.
* If no such M exists, return -1.
*/
export function getArrayMantissaMultiplier(xs: ArrayLike<number>, maxDigits: number, delta: number) {
let m = 1;
export function getArrayDigitCount(xs: ArrayLike<number>, maxDigits: number, delta: number) {
let mantissaDigits = 1;
let integerDigits = 0;
for (let i = 0, _i = xs.length; i < _i; i++) {
const t = getMantissaMultiplier(xs[i], maxDigits, delta);
if (t < 0) return -1;
if (t > m) m = t;
if (mantissaDigits >= 0) {
const t = getMantissaMultiplier(xs[i], maxDigits, delta);
if (t < 0) mantissaDigits = -1;
else if (t > mantissaDigits) mantissaDigits = t;
}
const abs = Math.abs(xs[i]);
if (abs > delta) {
const d = Math.floor(Math.log10(Math.abs(abs))) + 1;
if (d > integerDigits) integerDigits = d;
}
}
return m;
return { mantissaDigits, integerDigits };
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment