
javascript - Tensorflow JS - converting tensor to JSON and back to tensor - Stack Overflow


I am training a model in batches and am therefore saving its weights as JSON to store or send them.

Now I need to load them back into tensors. Is there a proper way to do this?

tensor.data().then(d => JSON.stringify(d));

// returns
{"0":0.000016666666851961054,"1":-0.00019999999494757503,"2":-0.000183333337190561}

I could iterate over this and convert it back to an array manually, but I feel there may be something in the API that does this more cleanly.
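The {"0": ...} output comes from how JSON.stringify handles typed arrays: data() resolves to a typed array such as Float32Array, and JSON.stringify serializes a typed array as an object with index keys rather than as a JSON array. A minimal plain-JavaScript sketch (no TensorFlow.js involved; the sample values are the ones from the output above) of recovering the values from that object form, and of avoiding it by converting to an ordinary array first:

```javascript
// Stand-in for the Float32Array that tensor.data() resolves to
const d = new Float32Array([0.000016666666851961054, -0.00019999999494757503, -0.000183333337190561]);

// JSON.stringify serializes a typed array as an index-keyed object: {"0":...,"1":...}
const asObject = JSON.stringify(d);

// Recovering from that form: integer-like keys are enumerated in ascending
// order, so Object.values returns the values in their original order
const fromObject = Float32Array.from(Object.values(JSON.parse(asObject)));

// Avoiding it: convert to a plain array first, which stringifies as [...]
const asArray = JSON.stringify(Array.from(d));
const fromArray = Float32Array.from(JSON.parse(asArray));
```

Both paths give back the same float32 values, because each number is written with enough digits to round-trip; the array form is simply the cleaner JSON to store or send.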


Asked Jul 25, 2019 at 17:44 by dendog; edited Aug 9, 2019 at 10:45 by edkeveked.

3 Answers


There is no need to stringify the typed array returned by data() as-is. To save a tensor and restore it later, two things are needed: the tensor's shape and its values as a flattened array (plus its dtype, if it is not the default float32).

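// read the flattened values, the shape and the dtype from the tensor
const saved = {data: Array.from(await tensor.data()), shape: tensor.shape, dtype: tensor.dtype}
const json = JSON.stringify(saved)  // store or send this string

// later: parse the string and rebuild the tensor
const restored = JSON.parse(json)
const retrievedTensor = tf.tensor(restored.data, restored.shape, restored.dtype)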

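Both pieces of information are captured at once by array() or arraySync(): they return a nested JavaScript array (not a typed array) with the same structure as the tensor, which can be stringified directly and passed back to tf.tensor.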

const saved = JSON.stringify(await tensor.array())  // nested array, e.g. [[1,2],[3,4]]
const retrievedTensor = tf.tensor(JSON.parse(saved))
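The two approaches store the same information in different layouts: data() gives the values flattened in row-major order, while array() nests them according to the shape. To illustrate that relationship, here is a hypothetical plain-JavaScript helper, unflatten (not part of TensorFlow.js), that turns a saved data/shape pair into the nested form array() would produce. In practice tf.tensor(data, shape) does this for you; the sketch only shows what the shape encodes.

```javascript
// Hypothetical helper (not a TensorFlow.js API): rebuild a nested array from
// flat row-major values and a shape, i.e. the form that array() returns.
function unflatten(data, shape) {
  const size = shape.reduce((a, b) => a * b, 1);
  if (size !== data.length) {
    throw new Error(`shape [${shape}] needs ${size} values, got ${data.length}`);
  }
  if (shape.length === 0) return data[0]; // scalar: shape []

  // Recursively slice the flat array; stride = number of values per sub-array
  const build = (offset, dims) => {
    if (dims.length === 1) return Array.from(data.slice(offset, offset + dims[0]));
    const stride = dims.slice(1).reduce((a, b) => a * b, 1);
    const out = [];
    for (let i = 0; i < dims[0]; i++) out.push(build(offset + i * stride, dims.slice(1)));
    return out;
  };
  return build(0, shape);
}
```

For example, unflatten([1, 2, 3, 4, 5, 6], [2, 3]) gives [[1, 2, 3], [4, 5, 6]], and a shape whose element count does not match the data length is rejected, as tf.tensor also does.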

Another option: export the weights as text with showWeights, save that text in a database, a text file or browser storage, for example, and later apply it to your model again with setWeightsFromString.

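// Serialize all model weights into two strings: the values of each weight
// tensor (comma-separated) and its shape, with ';' between tensors.
showWeights() {

    tf.tidy(() => {

        const weights = this.model.getWeights();
        let pesos = '';   // "pesos" = weight values
        let shapes = '';

        for (let i = 0; i < weights.length; i++) {

            const tensor = weights[i];
            const shape = tensor.shape;
            const values = tensor.dataSync().slice();  // copy of the flat values

            // Separate tensors by index, not by checking for an empty string:
            // a scalar weight has shape [], which stringifies to ''.
            if (i > 0) {
                pesos += ';';
                shapes += ';';
            }

            pesos += values;   // Float32Array -> "v0,v1,..."
            shapes += shape;   // [rows, cols] -> "rows,cols"

        }

        console.log(pesos);  // sValues for setWeightsFromString
        console.log(shapes); // sShapes for setWeightsFromString

    });

}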

// Rebuild the weight tensors from the two strings produced by showWeights
// and load them into the model.
setWeightsFromString(sValues, sShapes) {

    tf.tidy(() => {

        const aValues = sValues.split(';');
        const aShapes = sShapes.split(';');
        const loadedWeights = [];

        for (let i = 0; i < aValues.length; i++) {

            const newValues = new Float32Array(aValues[i].split(',').map(Number));
            // A scalar weight was saved with an empty shape string;
            // ''.split(',').map(Number) would wrongly give [0].
            const newShapes = aShapes[i] === '' ? [] : aShapes[i].split(',').map(Number);

            loadedWeights[i] = tf.tensor(newValues, newShapes);

        }

        this.model.setWeights(loadedWeights);

    });
}
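The same idea can be expressed with JSON instead of hand-built ';' and ',' separators. A minimal sketch, assuming each weight is described as a plain object with a shape array and a Float32Array of values; the helper names weightsToJSON and weightsFromJSON are hypothetical. In a real model you would build the input from model.getWeights() (w.shape, w.dataSync()) and pass the result to model.setWeights(entries.map(e => tf.tensor(e.values, e.shape))).

```javascript
// Hypothetical helpers: serialize a list of weights, each given as
// { shape: number[], values: Float32Array }, to a JSON string and back.
function weightsToJSON(weights) {
  return JSON.stringify(
    // Array.from turns the typed array into a real JSON array
    weights.map(({ shape, values }) => ({ shape: [...shape], values: Array.from(values) }))
  );
}

function weightsFromJSON(json) {
  return JSON.parse(json).map(({ shape, values }) => ({
    shape,                              // [] for scalars survives as []
    values: Float32Array.from(values),  // back to float32 storage
  }));
}
```

One advantage of JSON here is that the structure is explicit: an empty shape (a scalar weight) stays [] instead of becoming an empty string that has to be special-cased.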

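Here is my code for this in Node.js with @tensorflow/tfjs-node. It uses import and top-level await, so run it as an ES module (for example a .mjs file); the ./SAVED-TENSOR directory must already exist.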

import { 
    tensor,
    tensor2d,
} from '@tensorflow/tfjs-node'

import { readFile, writeFile } from 'node:fs'

const path2File = './SAVED-TENSOR/obj.json'

//------------------------------------- 2D ------------------------------------
const a = [
    [
        0.9969421029090881,
        9.39412784576416,
        95.00736999511719
    ]
]

const inputs2dT = tensor2d(a)
    
console.log(`@TENSOR  >> `, inputs2dT.dataSync())
// @TENSOR  >>  Float32Array(3) [
//     0.9969421029090881,
//     9.39412784576416,
//     95.00736999511719
// ]

const aa = await inputs2dT.array()
console.log(aa)
// [ [ 0.9969421029090881, 9.39412784576416, 95.00736999511719 ] ]

const aaObj = {
    "tensor": aa
}

writeFile(
    path2File,
    JSON.stringify(aaObj),
    (err) => {
        if (err) throw err

        console.log('@DATA >> Written!')

        // Read the file back only after the write has completed; calling
        // readFile right after writeFile would race with the write.
        readFile(path2File, 'utf8', (err, rawData) => {
            if (err) throw err
            const obj = JSON.parse(rawData)
            console.log('@DATA >> ', obj.tensor)

            const t = tensor(obj.tensor)
            if (t.constructor.name === 'Tensor') {
                t.print()
            } else {
                console.log('@UNDEFINED >> Tensor')
            }
        })
    }
)
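Because obj.tensor is just the nested array from array(), the shape is implied by the nesting, and tf.tensor infers it. As an illustration, here is a hypothetical plain-JavaScript check, inferShape (not a TensorFlow.js export), that sketches roughly that inference and rejects ragged data read from a file before it reaches tf.tensor.

```javascript
// Hypothetical helper: infer the shape of a nested number array, roughly as
// tf.tensor(nested) does, throwing if sub-arrays have different shapes.
function inferShape(value) {
  if (!Array.isArray(value)) {
    if (typeof value !== 'number') throw new Error(`expected a number, got ${typeof value}`);
    return []; // a bare number is a scalar
  }
  if (value.length === 0) return [0];
  const inner = inferShape(value[0]);
  // Every element must have the same shape as the first one
  for (const item of value) {
    if (JSON.stringify(inferShape(item)) !== JSON.stringify(inner)) {
      throw new Error('ragged array: sub-arrays have different shapes');
    }
  }
  return [value.length, ...inner];
}
```

For the array saved above, inferShape([[0.9969421029090881, 9.39412784576416, 95.00736999511719]]) returns [1, 3], the shape of inputs2dT.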