001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements. See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache license, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License. You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the license for the specific language governing permissions and
015 * limitations under the license.
016 */
017package org.apache.logging.log4j.util;
018
019import java.io.IOException;
020import java.io.InputStream;
021import java.io.InvalidObjectException;
022import java.io.ObjectInputStream;
023import java.io.ObjectStreamClass;
024import java.util.Arrays;
025import java.util.Collection;
026import java.util.Collections;
027import java.util.HashSet;
028import java.util.Set;
029
030/**
031 * Extends {@link ObjectInputStream} to only allow some built-in Log4j classes and caller-specified classes to be
032 * deserialized.
033 *
034 * @since 2.8.2
035 */
036public class FilteredObjectInputStream extends ObjectInputStream {
037
038    private static final Set<String> REQUIRED_JAVA_CLASSES = new HashSet<>(Arrays.asList(
039    // @formatter:off
040            "java.math.BigDecimal",
041            "java.math.BigInteger",
042            // for Message delegate
043            "java.rmi.MarshalledObject",
044            "[B"
045    // @formatter:on
046    ));
047
048    private static final Set<String> REQUIRED_JAVA_PACKAGES = new HashSet<>(Arrays.asList(
049    // @formatter:off
050            "java.lang.",
051            "java.time.",
052            "java.util.",
053            "org.apache.logging.log4j.",
054            "[Lorg.apache.logging.log4j."
055    // @formatter:on
056    ));
057
058    private final Collection<String> allowedExtraClasses;
059
060    public FilteredObjectInputStream() throws IOException, SecurityException {
061        this.allowedExtraClasses = Collections.emptySet();
062    }
063
064    public FilteredObjectInputStream(final InputStream inputStream) throws IOException {
065        super(inputStream);
066        this.allowedExtraClasses = Collections.emptySet();
067    }
068
069    public FilteredObjectInputStream(final Collection<String> allowedExtraClasses)
070        throws IOException, SecurityException {
071        this.allowedExtraClasses = allowedExtraClasses;
072    }
073
074    public FilteredObjectInputStream(final InputStream inputStream, final Collection<String> allowedExtraClasses)
075        throws IOException {
076        super(inputStream);
077        this.allowedExtraClasses = allowedExtraClasses;
078    }
079
080    public Collection<String> getAllowedClasses() {
081        return allowedExtraClasses;
082    }
083
084    @Override
085    protected Class<?> resolveClass(final ObjectStreamClass desc) throws IOException, ClassNotFoundException {
086        final String name = desc.getName();
087        if (!(isAllowedByDefault(name) || allowedExtraClasses.contains(name))) {
088            throw new InvalidObjectException("Class is not allowed for deserialization: " + name);
089        }
090        return super.resolveClass(desc);
091    }
092
093    private static boolean isAllowedByDefault(final String name) {
094        return isRequiredPackage(name) || REQUIRED_JAVA_CLASSES.contains(name);
095    }
096
097    private static boolean isRequiredPackage(final String name) {
098        for (final String packageName : REQUIRED_JAVA_PACKAGES) {
099            if (name.startsWith(packageName)) {
100                return true;
101            }
102        }
103        return false;
104    }
105
106}