Skip to content
Snippets Groups Projects
Commit 3add0d06 authored by Alexander Rose's avatar Alexander Rose
Browse files

add tensor handling to cif encoder

parent 13e63ced
No related branches found
No related tags found
No related merge requests found
...@@ -106,6 +106,6 @@ export function getTensor(category: Category, field: string, space: Tensor.Space ...@@ -106,6 +106,6 @@ export function getTensor(category: Category, field: string, space: Tensor.Space
} }
} }
} }
} else throw new Error('Tensors with rank > 3 currently not supported.'); } else throw new Error('Tensors with rank > 3 or rank 0 are currently not supported.');
return ret; return ret;
} }
\ No newline at end of file
...@@ -25,7 +25,7 @@ function getColumnCtor(t: Column.Schema): ColumnCtor { ...@@ -25,7 +25,7 @@ function getColumnCtor(t: Column.Schema): ColumnCtor {
case 'int': return (f, c, k) => createColumn(t, f, f.int, f.toIntArray); case 'int': return (f, c, k) => createColumn(t, f, f.int, f.toIntArray);
case 'float': return (f, c, k) => createColumn(t, f, f.float, f.toFloatArray); case 'float': return (f, c, k) => createColumn(t, f, f.float, f.toFloatArray);
case 'list': return (f, c, k) => createColumn(t, f, f.list, f.toListArray); case 'list': return (f, c, k) => createColumn(t, f, f.list, f.toListArray);
case 'tensor': throw new Error(`Use createTensorColumn instead.`); case 'tensor': throw new Error('Use createTensorColumn instead.');
} }
} }
...@@ -44,7 +44,14 @@ function createColumn<T>(schema: Column.Schema, field: Data.Field, value: (row: ...@@ -44,7 +44,14 @@ function createColumn<T>(schema: Column.Schema, field: Data.Field, value: (row:
function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Category, key: string): Column<Tensor> { function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Category, key: string): Column<Tensor> {
const space = schema.space; const space = schema.space;
const first = category.getField(`${key}[1]`) || Column.Undefined(category.rowCount, schema); let firstFieldName: string;
switch (space.rank) {
case 1: firstFieldName = `${key}[1]`; break;
case 2: firstFieldName = `${key}[1][1]`; break;
case 3: firstFieldName = `${key}[1][1][1]`; break;
default: throw new Error('Tensors with rank > 3 or rank 0 are currently not supported.');
}
const first = category.getField(firstFieldName) || Column.Undefined(category.rowCount, schema);
const value = (row: number) => Data.getTensor(category, key, space, row); const value = (row: number) => Data.getTensor(category, key, space, row);
const toArray: Column<Tensor>['toArray'] = params => ColumnHelpers.createAndFillArray(category.rowCount, value, params) const toArray: Column<Tensor>['toArray'] = params => ColumnHelpers.createAndFillArray(category.rowCount, value, params)
......
/** /**
* Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info. * Copyright (c) 2017-2018 mol* contributors, licensed under MIT, See LICENSE file for more info.
* *
* @author David Sehnal <david.sehnal@gmail.com> * @author David Sehnal <david.sehnal@gmail.com>
* @author Alexander Rose <alexander.rose@weirdbyte.de>
*/ */
import Iterator from 'mol-data/iterator' import Iterator from 'mol-data/iterator'
import { Column } from 'mol-data/db' import { Column, Table } from 'mol-data/db'
import { Tensor } from 'mol-math/linear-algebra'
import Encoder from '../encoder' import Encoder from '../encoder'
// TODO: support for "coordinate fields", make "coordinate precision" a parameter of the encoder // TODO: support for "coordinate fields", make "coordinate precision" a parameter of the encoder
...@@ -13,7 +15,6 @@ import Encoder from '../encoder' ...@@ -13,7 +15,6 @@ import Encoder from '../encoder'
// TODO: automatically detect "best encoding" for integer arrays. This could be used for "fixed-point" as well. // TODO: automatically detect "best encoding" for integer arrays. This could be used for "fixed-point" as well.
// TODO: add "repeat encoding"? [[1, 2], [1, 2], [1, 2]] --- Repeat ---> [[1, 2], 3] // TODO: add "repeat encoding"? [[1, 2], [1, 2], [1, 2]] --- Repeat ---> [[1, 2], 3]
// TODO: Add "higher level fields"? (i.e. generalization of repeat) // TODO: Add "higher level fields"? (i.e. generalization of repeat)
// TODO: Add tensor field definition
// TODO: align "data blocks" to 8 byte offsets for fast typed array windows? (prolly needs some testing if this is actually the case too) // TODO: align "data blocks" to 8 byte offsets for fast typed array windows? (prolly needs some testing if this is actually the case too)
// TODO: "parametric encoders"? Specify encoding as [{ param: 'value1', encoding1 }, { param: 'value2', encoding2 }] // TODO: "parametric encoders"? Specify encoding as [{ param: 'value1', encoding1 }, { param: 'value2', encoding2 }]
// then the encoder can specify { param: 'value1' } and the correct encoding will be used. // then the encoder can specify { param: 'value1' } and the correct encoding will be used.
...@@ -35,7 +36,6 @@ export type FieldDefinition<Key = any, Data = any> = ...@@ -35,7 +36,6 @@ export type FieldDefinition<Key = any, Data = any> =
| FieldDefinitionBase<Key, Data> & { type: FieldType.Str, value(key: Key, data: Data): string } | FieldDefinitionBase<Key, Data> & { type: FieldType.Str, value(key: Key, data: Data): string }
| FieldDefinitionBase<Key, Data> & { type: FieldType.Int, value(key: Key, data: Data): number } | FieldDefinitionBase<Key, Data> & { type: FieldType.Int, value(key: Key, data: Data): number }
| FieldDefinitionBase<Key, Data> & { type: FieldType.Float, value(key: Key, data: Data): number } | FieldDefinitionBase<Key, Data> & { type: FieldType.Float, value(key: Key, data: Data): number }
// TODO: add tensor
export interface FieldFormat { export interface FieldFormat {
// TODO: do we actually need this? // TODO: do we actually need this?
...@@ -75,3 +75,78 @@ export interface CIFEncoder<T = string | Uint8Array, Context = any> extends Enco ...@@ -75,3 +75,78 @@ export interface CIFEncoder<T = string | Uint8Array, Context = any> extends Enco
writeCategory(category: CategoryProvider, contexts?: Context[]): void, writeCategory(category: CategoryProvider, contexts?: Context[]): void,
getData(): T getData(): T
} }
function columnValue(k: string) {
return (i: number, d: any) => d[k].value(i);
}
function columnTensorValue(k: string, ...coords: number[]) {
return (i: number, d: any) => d[k].schema.space.get(d[k].value(i), ...coords);
}
function columnValueKind(k: string) {
return (i: number, d: any) => d[k].valueKind(i);
}
function getTensorDefinitions(field: string, space: Tensor.Space) {
const fieldDefinitions: FieldDefinition[] = []
const type = FieldType.Float
const valueKind = columnValueKind(field)
if (space.rank === 1) {
const rows = space.dimensions[0]
for (let i = 0; i < rows; i++) {
const name = `${field}[${i + 1}]`
fieldDefinitions.push({ name, type, value: columnTensorValue(field, i), valueKind })
}
} else if (space.rank === 2) {
const rows = space.dimensions[0], cols = space.dimensions[1]
for (let i = 0; i < rows; i++) {
for (let j = 0; j < cols; j++) {
const name = `${field}[${i + 1}][${j + 1}]`
fieldDefinitions.push({ name, type, value: columnTensorValue(field, i, j), valueKind })
}
}
} else if (space.rank === 3) {
const d0 = space.dimensions[0], d1 = space.dimensions[1], d2 = space.dimensions[2]
for (let i = 0; i < d0; i++) {
for (let j = 0; j < d1; j++) {
for (let k = 0; k < d2; k++) {
const name = `${field}[${i + 1}][${j + 1}][${k + 1}]`
fieldDefinitions.push({ name, type, value: columnTensorValue(field, i, j, k), valueKind })
}
}
}
} else {
throw new Error('Tensors with rank > 3 or rank 0 are currently not supported.')
}
return fieldDefinitions
}
export namespace FieldDefinitions {
export function ofSchema(schema: Table.Schema) {
const fields: FieldDefinition[] = [];
for (const k of Object.keys(schema)) {
const t = schema[k];
if (t.valueType === 'int') {
fields.push({ name: k, type: FieldType.Int, value: columnValue(k), valueKind: columnValueKind(k) });
} else if (t.valueType === 'float') {
fields.push({ name: k, type: FieldType.Float, value: columnValue(k), valueKind: columnValueKind(k) });
} else if (t.valueType === 'str') {
fields.push({ name: k, type: FieldType.Str, value: columnValue(k), valueKind: columnValueKind(k) });
} else if (t.valueType === 'list') {
throw new Error('list not implemented');
} else if (t.valueType === 'tensor') {
fields.push(...getTensorDefinitions(k, t.space))
} else {
throw new Error(`Unknown valueType ${t.valueType}`);
}
}
return fields;
}
}
export namespace CategoryDefinition {
export function ofTable<S extends Table.Schema>(name: string, table: Table<S>): CategoryDefinition<number> {
return { name, fields: FieldDefinitions.ofSchema(table._schema) }
}
}
/** /**
* Copyright (c) 2017 mol* contributors, licensed under MIT, See LICENSE file for more info. * Copyright (c) 2017-2018 mol* contributors, licensed under MIT, See LICENSE file for more info.
* *
* @author David Sehnal <david.sehnal@gmail.com> * @author David Sehnal <david.sehnal@gmail.com>
* @author Alexander Rose <alexander.rose@weirdbyte.de>
*/ */
import { Column, Table } from 'mol-data/db' import { Column } from 'mol-data/db'
import Iterator from 'mol-data/iterator' import Iterator from 'mol-data/iterator'
import * as Encoder from 'mol-io/writer/cif' import * as Encoder from 'mol-io/writer/cif'
// import { mmCIF_Schema } from 'mol-io/reader/cif/schema/mmcif' // import { mmCIF_Schema } from 'mol-io/reader/cif/schema/mmcif'
...@@ -36,29 +37,6 @@ function float<K, D = any>(name: string, value: (k: K, d: D) => number, valueKin ...@@ -36,29 +37,6 @@ function float<K, D = any>(name: string, value: (k: K, d: D) => number, valueKin
// return { name, type, value, valueKind } // return { name, type, value, valueKind }
// } // }
function columnValue(k: string) {
return (i: number, d: any) => d[k].value(i);
}
function columnValueKind(k: string) {
return (i: number, d: any) => d[k].valueKind(i);
}
function ofSchema(schema: Table.Schema) {
const fields: Encoder.FieldDefinition[] = [];
for (const k of Object.keys(schema)) {
const t = schema[k];
// TODO: matrix/vector/support
const type: any = t.valueType === 'str' ? Encoder.FieldType.Str : t.valueType === 'int' ? Encoder.FieldType.Int : Encoder.FieldType.Float;
fields.push({ name: k, type, value: columnValue(k), valueKind: columnValueKind(k) })
}
return fields;
}
function ofTable<S extends Table.Schema>(name: string, table: Table<S>): Encoder.CategoryDefinition<number> {
return { name, fields: ofSchema(table._schema) }
}
// type Entity = Table.Columns<typeof mmCIF_Schema.entity> // type Entity = Table.Columns<typeof mmCIF_Schema.entity>
// const entity: Encoder.CategoryDefinition<number, Entity> = { // const entity: Encoder.CategoryDefinition<number, Entity> = {
...@@ -66,7 +44,6 @@ function ofTable<S extends Table.Schema>(name: string, table: Table<S>): Encoder ...@@ -66,7 +44,6 @@ function ofTable<S extends Table.Schema>(name: string, table: Table<S>): Encoder
// fields: ofSchema(mmCIF_Schema.entity) // fields: ofSchema(mmCIF_Schema.entity)
// } // }
// [ // [
// str('id', (i, e) => e.id.value(i)), // str('id', (i, e) => e.id.value(i)),
// str('type', (i, e) => e.type.value(i)), // str('type', (i, e) => e.type.value(i)),
...@@ -126,7 +103,7 @@ const atom_site: Encoder.CategoryDefinition<Atom.Location> = { ...@@ -126,7 +103,7 @@ const atom_site: Encoder.CategoryDefinition<Atom.Location> = {
function entityProvider({ model }: Context): Encoder.CategoryInstance { function entityProvider({ model }: Context): Encoder.CategoryInstance {
return { return {
data: model.hierarchy.entities, data: model.hierarchy.entities,
definition: ofTable('entity', model.hierarchy.entities), //entity, definition: Encoder.CategoryDefinition.ofTable('entity', model.hierarchy.entities),
keys: () => Iterator.Range(0, model.hierarchy.entities._rowCount - 1), keys: () => Iterator.Range(0, model.hierarchy.entities._rowCount - 1),
rowCount: model.hierarchy.entities._rowCount rowCount: model.hierarchy.entities._rowCount
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment