
# Training-aware ML APIs

tensorflow/swift
Based on 3 comments

Design machine learning APIs with explicit parameters that distinguish between training and inference phases, rather than relying on global state or implicit context. This pattern enables concurrent training and evaluation, improves model reusability, and prevents subtle bugs in distributed training scenarios.


## Reviewer Prompt


```swift
// Instead of relying on implicit global context:
protocol Layer {
    func call(_ input: Input) -> Output  // Implicit training/inference mode is problematic
}

// Use an explicit training parameter:
protocol Layer {
    associatedtype Input: Differentiable
    associatedtype Output: Differentiable

    @differentiable(wrt: (self, input))
    func call(_ input: Input, training: Bool) -> Output
}

// Implementation example. Note: calculateBatchStatistics, updateRunningStatistics,
// and normalize are illustrative helpers, and updating running statistics from a
// non-mutating method requires reference semantics for the stored statistics.
struct BatchNorm: Layer {
    var scale: Tensor<Float>
    var offset: Tensor<Float>
    @noDerivative var runningMean: Tensor<Float>
    @noDerivative var runningVariance: Tensor<Float>

    @differentiable(wrt: (self, input))
    func call(_ input: Tensor<Float>, training: Bool) -> Tensor<Float> {
        if training {
            // Use batch statistics and update the running statistics.
            let batchStatistics = calculateBatchStatistics(input)
            updateRunningStatistics(batchStatistics)
            return normalize(input, using: batchStatistics)
        } else {
            // Use the accumulated running statistics.
            return normalize(input, using: (runningMean, runningVariance))
        }
    }
}
```
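
Because the mode is carried in the call rather than in shared mutable state, the same model value can serve a training step and a concurrent evaluation pass without one toggling a flag the other observes. A minimal usage sketch, assuming a Swift for TensorFlow toolchain and the `BatchNorm` layer above; `trainingBatch` and `validationBatch` are hypothetical tensors supplied by the caller:

```swift
// Memberwise initialization of the BatchNorm sketch for an 8-feature input.
let bn = BatchNorm(
    scale: Tensor<Float>(ones: [8]),
    offset: Tensor<Float>(zeros: [8]),
    runningMean: Tensor<Float>(zeros: [8]),
    runningVariance: Tensor<Float>(ones: [8]))

// Training step: batch statistics are computed and running statistics updated.
let trainOutput = bn.call(trainingBatch, training: true)

// Evaluation, possibly running concurrently: only the running statistics are
// read, and no global "learning phase" exists for the two calls to race on.
let evalOutput = bn.call(validationBatch, training: false)
```

With a global-context design, the second call would have to flip a process-wide mode switch, which is exactly what breaks concurrent train/eval loops and distributed workers sharing a runtime.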