Skip to content
Merged
24 changes: 22 additions & 2 deletions types/ndarray/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,21 @@ declare namespace ndarray {
T: NdArray<D>;
}

type Data<T = any> = T[] | TypedArray;
interface GenericArray<T> {
get(idx: number): T;
set(idx: number, value: T): void;
length: number;
}

type MaybeBigInt64Array = InstanceType<typeof globalThis extends { BigInt64Array: infer T } ? T : never>;
type MaybeBigUint64Array = InstanceType<typeof globalThis extends { BigUint64Array: infer T } ? T : never>;

type Data<T = any> = T extends number
? GenericArray<T> | T[] | TypedArray
: T extends bigint
? GenericArray<T> | T[] | MaybeBigInt64Array | MaybeBigUint64Array
: GenericArray<T> | T[];

type TypedArray =
| Int8Array
| Int16Array
Expand All @@ -46,7 +60,7 @@ declare namespace ndarray {
| Float32Array
| Float64Array;

type Value<D extends Data> = D extends Array<infer T> ? T : number;
type Value<D extends Data> = D extends GenericArray<infer T> | Record<number, infer T> ? T : never;

type DataType<D extends Data = Data> = D extends Int8Array
? 'int8'
Expand All @@ -66,6 +80,12 @@ declare namespace ndarray {
? 'float32'
: D extends Float64Array
? 'float64'
: D extends MaybeBigInt64Array
? 'bigint64'
: D extends MaybeBigUint64Array
? 'biguint64'
: D extends GenericArray<unknown>
? 'generic'
: 'array';
}

Expand Down
38 changes: 38 additions & 0 deletions types/ndarray/ndarray-tests.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
/// <reference lib="esnext.bigint" />
import ndarray = require('ndarray');

const data1 = new Int32Array(2 * 2 * 2 + 10);
Expand Down Expand Up @@ -56,3 +57,40 @@ console.log(typeof firstVal === 'string' ? firstVal.length : firstVal.valueOf())
function getFirstValue(arr: ndarray.NdArray): number {
return arr.get(0);
}

function getTypedArrayOrNumberArray(arr: ndarray.NdArray<ndarray.Data<number>>): {
data: ndarray.GenericArray<number> | ndarray.TypedArray | number[];
scalar: number;
} {
return { data: arr.data, scalar: arr.get(0) };
}

function getBigIntTypedArrayOrBigIntArray(arr: ndarray.NdArray<ndarray.Data<bigint>>): {
data: ndarray.GenericArray<bigint> | BigUint64Array | BigInt64Array | Array<bigint>;
scalar: bigint;
} {
return { data: arr.data, scalar: arr.get(0) };
}

function getBigIntOrNumeric(arr: ndarray.NdArray<ndarray.Data<number | bigint>>): {
data: ndarray.GenericArray<number> | ndarray.GenericArray<bigint> | number[] | ndarray.TypedArray | BigUint64Array | BigInt64Array | Array<bigint>;
scalar: number | bigint;
} {
return { data: arr.data, scalar: arr.get(0) };
}

function infersStringOnly(arr: ndarray.NdArray<ndarray.Data<string>>): { data: ndarray.GenericArray<string> | string[]; scalar: string } {
return { data: arr.data, scalar: arr.get(0) };
}

function genericDtype(arr: ndarray.NdArray<ndarray.GenericArray<any>>): 'generic' {
return arr.dtype;
}

function bigintDtype(arr: ndarray.NdArray<ndarray.Data<bigint>>): 'generic' | 'array' | 'bigint64' | 'biguint64' {
return arr.dtype;
}

function stringDtype(arr: ndarray.NdArray<ndarray.Data<string>>): 'generic' | 'array' {
return arr.dtype;
}
4 changes: 2 additions & 2 deletions types/numjs/numjs-tests.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ arr.transpose(1, 0, 2);
// array([[[ 0, 1, 2]],
// [[ 3, 4, 5]]])

const b = nj.array([2, 3, 4]);
const b = nj.array([2, 3, 4] as number[]);

const c = nj.uint8([1, 2, 3]);
const c = nj.uint8([1, 2, 3] as number[]);

const d = nj.array<number[]>([[2], [3, 4]]);

Expand Down