Skip to content
Snippets Groups Projects
Commit 0e16108f authored by David Sehnal's avatar David Sehnal
Browse files

Updated tensors in mol-math

parent 606dde23
No related branches found
No related tags found
No related merge requests found
...@@ -34,7 +34,7 @@ namespace Column { ...@@ -34,7 +34,7 @@ namespace Column {
export type Float = { '@type': 'float', T: number } & Base<'float'> export type Float = { '@type': 'float', T: number } & Base<'float'>
export type Coordinate = { '@type': 'coord', T: number } & Base<'float'> export type Coordinate = { '@type': 'coord', T: number } & Base<'float'>
export type Tensor = { '@type': 'tensor', T: Tensors, space: Tensors.Space } & Base<'tensor'> export type Tensor = { '@type': 'tensor', T: Tensors.Data, space: Tensors.Space } & Base<'tensor'>
export type Aliased<T> = { '@type': 'aliased', T: T } & Base<'str' | 'int'> export type Aliased<T> = { '@type': 'aliased', T: T } & Base<'str' | 'int'>
export type List<T extends number|string> = { '@type': 'list', T: T[], separator: string, itemParse: (x: string) => T } & Base<'list'> export type List<T extends number|string> = { '@type': 'list', T: T[], separator: string, itemParse: (x: string) => T } & Base<'list'>
......
...@@ -78,7 +78,7 @@ export interface Field { ...@@ -78,7 +78,7 @@ export interface Field {
toFloatArray(params?: Column.ToArrayParams<number>): ReadonlyArray<number> toFloatArray(params?: Column.ToArrayParams<number>): ReadonlyArray<number>
} }
export function getTensor(category: Category, field: string, space: Tensor.Space, row: number): Tensor { export function getTensor(category: Category, field: string, space: Tensor.Space, row: number): Tensor.Data {
const ret = space.create(); const ret = space.create();
if (space.rank === 1) { if (space.rank === 1) {
const rows = space.dimensions[0]; const rows = space.dimensions[0];
......
...@@ -71,7 +71,7 @@ function createListColumn<T extends number|string>(schema: Column.Schema.List<T> ...@@ -71,7 +71,7 @@ function createListColumn<T extends number|string>(schema: Column.Schema.List<T>
}; };
} }
function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Category, key: string): Column<Tensor> { function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Category, key: string): Column<Tensor.Data> {
const space = schema.space; const space = schema.space;
let firstFieldName: string; let firstFieldName: string;
switch (space.rank) { switch (space.rank) {
...@@ -82,7 +82,7 @@ function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Categor ...@@ -82,7 +82,7 @@ function createTensorColumn(schema: Column.Schema.Tensor, category: Data.Categor
} }
const first = category.getField(firstFieldName) || Column.Undefined(category.rowCount, schema); const first = category.getField(firstFieldName) || Column.Undefined(category.rowCount, schema);
const value = (row: number) => Data.getTensor(category, key, space, row); const value = (row: number) => Data.getTensor(category, key, space, row);
const toArray: Column<Tensor>['toArray'] = params => ColumnHelpers.createAndFillArray(category.rowCount, value, params) const toArray: Column<Tensor.Data>['toArray'] = params => ColumnHelpers.createAndFillArray(category.rowCount, value, params)
return { return {
schema, schema,
......
...@@ -6,18 +6,20 @@ ...@@ -6,18 +6,20 @@
import { Mat4, Vec3, Vec4 } from './3d' import { Mat4, Vec3, Vec4 } from './3d'
export interface Tensor extends Array<number> { '@type': 'tensor' } export interface Tensor { data: Tensor.Data, space: Tensor.Space }
export namespace Tensor { export namespace Tensor {
export type ArrayCtor = { new (size: number): ArrayLike<number> } export type ArrayCtor = { new (size: number): ArrayLike<number> }
export interface Data extends Array<number> { '@type': 'tensor' }
export interface Space { export interface Space {
readonly rank: number, readonly rank: number,
readonly dimensions: ReadonlyArray<number>, readonly dimensions: ReadonlyArray<number>,
readonly axisOrderSlowToFast: ReadonlyArray<number>, readonly axisOrderSlowToFast: ReadonlyArray<number>,
create(array?: ArrayCtor): Tensor, create(array?: ArrayCtor): Tensor.Data,
get(data: Tensor, ...coords: number[]): number get(data: Tensor.Data, ...coords: number[]): number
set(data: Tensor, ...coordsAndValue: number[]): number set(data: Tensor.Data, ...coordsAndValue: number[]): number
} }
interface Layout { interface Layout {
...@@ -39,6 +41,8 @@ export namespace Tensor { ...@@ -39,6 +41,8 @@ export namespace Tensor {
return { dimensions, axisOrderFastToSlow, axisOrderSlowToFast, accessDimensions, defaultCtor: ctor || Float64Array } return { dimensions, axisOrderFastToSlow, axisOrderSlowToFast, accessDimensions, defaultCtor: ctor || Float64Array }
} }
export function create(space: Space, data: Data): Tensor { return { space, data }; }
export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space { export function Space(dimensions: number[], axisOrderSlowToFast: number[], ctor?: ArrayCtor): Space {
const layout = Layout(dimensions, axisOrderSlowToFast, ctor); const layout = Layout(dimensions, axisOrderSlowToFast, ctor);
const { get, set } = accessors(layout); const { get, set } = accessors(layout);
...@@ -49,7 +53,7 @@ export namespace Tensor { ...@@ -49,7 +53,7 @@ export namespace Tensor {
export function ColumnMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [1, 0], ctor); } export function ColumnMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [1, 0], ctor); }
export function RowMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [0, 1], ctor); } export function RowMajorMatrix(rows: number, cols: number, ctor?: ArrayCtor) { return Space([rows, cols], [0, 1], ctor); }
export function toMat4(space: Space, data: Tensor): Mat4 { export function toMat4(space: Space, data: Tensor.Data): Mat4 {
if (space.rank !== 2) throw new Error('Invalid tensor rank'); if (space.rank !== 2) throw new Error('Invalid tensor rank');
const mat = Mat4.zero(); const mat = Mat4.zero();
const d0 = Math.min(4, space.dimensions[0]), d1 = Math.min(4, space.dimensions[1]); const d0 = Math.min(4, space.dimensions[0]), d1 = Math.min(4, space.dimensions[1]);
...@@ -59,7 +63,7 @@ export namespace Tensor { ...@@ -59,7 +63,7 @@ export namespace Tensor {
return mat; return mat;
} }
export function toVec3(space: Space, data: Tensor): Vec3 { export function toVec3(space: Space, data: Tensor.Data): Vec3 {
if (space.rank !== 1) throw new Error('Invalid tensor rank'); if (space.rank !== 1) throw new Error('Invalid tensor rank');
const vec = Vec3.zero(); const vec = Vec3.zero();
const d0 = Math.min(3, space.dimensions[0]); const d0 = Math.min(3, space.dimensions[0]);
...@@ -67,7 +71,7 @@ export namespace Tensor { ...@@ -67,7 +71,7 @@ export namespace Tensor {
return vec; return vec;
} }
export function toVec4(space: Space, data: Tensor): Vec4 { export function toVec4(space: Space, data: Tensor.Data): Vec4 {
if (space.rank !== 1) throw new Error('Invalid tensor rank'); if (space.rank !== 1) throw new Error('Invalid tensor rank');
const vec = Vec4.zero(); const vec = Vec4.zero();
const d0 = Math.min(4, space.dimensions[0]); const d0 = Math.min(4, space.dimensions[0]);
...@@ -75,7 +79,7 @@ export namespace Tensor { ...@@ -75,7 +79,7 @@ export namespace Tensor {
return vec; return vec;
} }
export function areEqualExact(a: Tensor, b: Tensor) { export function areEqualExact(a: Tensor.Data, b: Tensor.Data) {
const len = a.length; const len = a.length;
if (len !== b.length) return false; if (len !== b.length) return false;
for (let i = 0; i < len; i++) if (a[i] !== b[i]) return false; for (let i = 0; i < len; i++) if (a[i] !== b[i]) return false;
...@@ -136,7 +140,7 @@ export namespace Tensor { ...@@ -136,7 +140,7 @@ export namespace Tensor {
const { dimensions: ds } = layout; const { dimensions: ds } = layout;
let size = 1; let size = 1;
for (let i = 0, _i = ds.length; i < _i; i++) size *= ds[i]; for (let i = 0, _i = ds.length; i < _i; i++) size *= ds[i];
return ctor => new (ctor || layout.defaultCtor)(size) as Tensor; return ctor => new (ctor || layout.defaultCtor)(size) as Tensor.Data;
} }
function dataOffset(layout: Layout, coord: number[]) { function dataOffset(layout: Layout, coord: number[]) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment